Skip to content

Commit 50df54b

Browse files
authored
Fix custom and off topic to use llm_base (#31)
1 parent 0f91e41 commit 50df54b

File tree

4 files changed

+508
-417
lines changed

4 files changed

+508
-417
lines changed

src/__tests__/unit/checks/topical-alignment.test.ts

Lines changed: 203 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,20 @@
22
* Tests for the topical alignment guardrail.
33
*/
44

5-
import { describe, it, expect, vi, afterEach } from 'vitest';
5+
import { describe, it, expect, vi, beforeEach } from 'vitest';
6+
import { GuardrailLLMContext } from '../../../types';
67

7-
const buildFullPromptMock = vi.fn((prompt: string) => `FULL:${prompt}`);
8+
const createLLMCheckFnMock = vi.fn(() => 'mocked-guardrail');
89
const registerMock = vi.fn();
910

1011
vi.mock('../../../checks/llm-base', () => ({
11-
buildFullPrompt: buildFullPromptMock,
12+
createLLMCheckFn: createLLMCheckFnMock,
13+
LLMConfig: {
14+
omit: vi.fn(() => ({
15+
extend: vi.fn(() => ({})),
16+
})),
17+
},
18+
LLMOutput: {},
1219
}));
1320

1421
vi.mock('../../../registry', () => ({
@@ -17,9 +24,23 @@ vi.mock('../../../registry', () => ({
1724
},
1825
}));
1926

20-
describe('topicalAlignmentCheck', () => {
21-
afterEach(() => {
22-
buildFullPromptMock.mockClear();
27+
describe('topicalAlignment guardrail', () => {
28+
beforeEach(() => {
29+
registerMock.mockClear();
30+
createLLMCheckFnMock.mockClear();
31+
});
32+
33+
it('is created via createLLMCheckFn', async () => {
34+
const { topicalAlignment } = await import('../../../checks/topical-alignment');
35+
36+
expect(topicalAlignment).toBe('mocked-guardrail');
37+
expect(createLLMCheckFnMock).toHaveBeenCalled();
38+
});
39+
});
40+
41+
describe('topicalAlignment integration tests', () => {
42+
beforeEach(() => {
43+
vi.resetModules();
2344
});
2445

2546
interface TopicalAlignmentConfig {
@@ -28,12 +49,6 @@ describe('topicalAlignmentCheck', () => {
2849
system_prompt_details: string;
2950
}
3051

31-
const config: TopicalAlignmentConfig = {
32-
model: 'gpt-topic',
33-
confidence_threshold: 0.6,
34-
system_prompt_details: 'Stay on topic about finance.',
35-
};
36-
3752
interface MockLLMResponse {
3853
choices: Array<{
3954
message: {
@@ -42,8 +57,13 @@ describe('topicalAlignmentCheck', () => {
4257
}>;
4358
}
4459

45-
const makeCtx = (response: MockLLMResponse) => {
46-
const create = vi.fn().mockResolvedValue(response);
60+
const makeCtx = (response: MockLLMResponse, capturedParams?: { value?: unknown }) => {
61+
const create = vi.fn().mockImplementation((params) => {
62+
if (capturedParams) {
63+
capturedParams.value = params;
64+
}
65+
return Promise.resolve(response);
66+
});
4767
return {
4868
ctx: {
4969
guardrailLlm: {
@@ -52,83 +72,208 @@ describe('topicalAlignmentCheck', () => {
5272
create,
5373
},
5474
},
75+
baseURL: 'https://api.openai.com/v1',
5576
},
56-
},
77+
} as GuardrailLLMContext,
5778
create,
5879
};
5980
};
6081

61-
it('triggers when LLM flags off-topic content above threshold', async () => {
62-
const { topicalAlignmentCheck } = await import('../../../checks/topical-alignment');
63-
const { ctx, create } = makeCtx({
64-
choices: [
65-
{
66-
message: {
67-
content: JSON.stringify({ flagged: true, confidence: 0.8 }),
82+
it('triggers when LLM flags off-topic content above threshold with gpt-4', async () => {
83+
vi.doUnmock('../../../checks/llm-base');
84+
vi.doUnmock('../../../checks/topical-alignment');
85+
86+
const { topicalAlignment } = await import('../../../checks/topical-alignment');
87+
const capturedParams: { value?: unknown } = {};
88+
const { ctx, create } = makeCtx(
89+
{
90+
choices: [
91+
{
92+
message: {
93+
content: JSON.stringify({ flagged: true, confidence: 0.8 }),
94+
},
6895
},
69-
},
70-
],
71-
});
96+
],
97+
},
98+
capturedParams
99+
);
72100

73-
const result = await topicalAlignmentCheck(ctx, 'Discussing sports', config);
101+
const config: TopicalAlignmentConfig = {
102+
model: 'gpt-4',
103+
confidence_threshold: 0.7,
104+
system_prompt_details: 'Stay on topic about finance.',
105+
};
74106

75-
expect(buildFullPromptMock).toHaveBeenCalled();
76-
expect(create).toHaveBeenCalledWith({
77-
messages: [
78-
{ role: 'system', content: expect.stringContaining('Stay on topic about finance.') },
79-
{ role: 'user', content: 'Discussing sports' },
80-
],
81-
model: 'gpt-topic',
82-
temperature: 0.0,
83-
response_format: { type: 'json_object' },
84-
});
107+
const result = await topicalAlignment(ctx, 'Discussing sports', config);
108+
109+
expect(create).toHaveBeenCalled();
110+
const params = capturedParams.value as Record<string, unknown>;
111+
expect(params.model).toBe('gpt-4');
112+
expect(params.temperature).toBe(0.0); // gpt-4 uses temperature 0
113+
expect(params.response_format).toEqual({ type: 'json_object' });
85114
expect(result.tripwireTriggered).toBe(true);
86115
expect(result.info?.flagged).toBe(true);
87116
expect(result.info?.confidence).toBe(0.8);
88117
});
89118

90-
it('returns failure info when no content is returned', async () => {
91-
const { topicalAlignmentCheck } = await import('../../../checks/topical-alignment');
119+
it('uses temperature 1.0 for gpt-5 models (which do not support temperature 0)', async () => {
120+
vi.doUnmock('../../../checks/llm-base');
121+
vi.doUnmock('../../../checks/topical-alignment');
122+
123+
const { topicalAlignment } = await import('../../../checks/topical-alignment');
124+
const capturedParams: { value?: unknown } = {};
125+
const { ctx, create } = makeCtx(
126+
{
127+
choices: [
128+
{
129+
message: {
130+
content: JSON.stringify({ flagged: false, confidence: 0.2 }),
131+
},
132+
},
133+
],
134+
},
135+
capturedParams
136+
);
137+
138+
const config: TopicalAlignmentConfig = {
139+
model: 'gpt-5',
140+
confidence_threshold: 0.7,
141+
system_prompt_details: 'Stay on topic about technology.',
142+
};
143+
144+
const result = await topicalAlignment(ctx, 'Discussing AI and ML', config);
145+
146+
expect(create).toHaveBeenCalled();
147+
const params = capturedParams.value as Record<string, unknown>;
148+
expect(params.model).toBe('gpt-5');
149+
expect(params.temperature).toBe(1.0); // gpt-5 uses temperature 1.0, not 0
150+
expect(params.response_format).toEqual({ type: 'json_object' });
151+
expect(result.tripwireTriggered).toBe(false);
152+
});
153+
154+
it('works with gpt-4o model', async () => {
155+
vi.doUnmock('../../../checks/llm-base');
156+
vi.doUnmock('../../../checks/topical-alignment');
157+
158+
const { topicalAlignment } = await import('../../../checks/topical-alignment');
159+
const capturedParams: { value?: unknown } = {};
160+
const { ctx, create } = makeCtx(
161+
{
162+
choices: [
163+
{
164+
message: {
165+
content: JSON.stringify({ flagged: true, confidence: 0.9 }),
166+
},
167+
},
168+
],
169+
},
170+
capturedParams
171+
);
172+
173+
const config: TopicalAlignmentConfig = {
174+
model: 'gpt-4o',
175+
confidence_threshold: 0.8,
176+
system_prompt_details: 'Stay on topic about healthcare.',
177+
};
178+
179+
const result = await topicalAlignment(ctx, 'Talking about cars', config);
180+
181+
expect(create).toHaveBeenCalled();
182+
const params = capturedParams.value as Record<string, unknown>;
183+
expect(params.model).toBe('gpt-4o');
184+
expect(params.temperature).toBe(0.0); // gpt-4o uses temperature 0
185+
expect(result.tripwireTriggered).toBe(true);
186+
});
187+
188+
it('works with gpt-3.5-turbo model', async () => {
189+
vi.doUnmock('../../../checks/llm-base');
190+
vi.doUnmock('../../../checks/topical-alignment');
191+
192+
const { topicalAlignment } = await import('../../../checks/topical-alignment');
193+
const capturedParams: { value?: unknown } = {};
194+
const { ctx, create } = makeCtx(
195+
{
196+
choices: [
197+
{
198+
message: {
199+
content: JSON.stringify({ flagged: false, confidence: 0.3 }),
200+
},
201+
},
202+
],
203+
},
204+
capturedParams
205+
);
206+
207+
const config: TopicalAlignmentConfig = {
208+
model: 'gpt-3.5-turbo',
209+
confidence_threshold: 0.7,
210+
system_prompt_details: 'Stay on topic about education.',
211+
};
212+
213+
const result = await topicalAlignment(ctx, 'Discussing teaching methods', config);
214+
215+
expect(create).toHaveBeenCalled();
216+
const params = capturedParams.value as Record<string, unknown>;
217+
expect(params.model).toBe('gpt-3.5-turbo');
218+
expect(params.temperature).toBe(0.0);
219+
expect(result.tripwireTriggered).toBe(false);
220+
});
221+
222+
it('does not trigger when confidence is below threshold', async () => {
223+
vi.doUnmock('../../../checks/llm-base');
224+
vi.doUnmock('../../../checks/topical-alignment');
225+
226+
const { topicalAlignment } = await import('../../../checks/topical-alignment');
92227
const { ctx } = makeCtx({
93-
choices: [{ message: { content: '' } }],
228+
choices: [
229+
{
230+
message: {
231+
content: JSON.stringify({ flagged: true, confidence: 0.5 }),
232+
},
233+
},
234+
],
94235
});
95236

96-
const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {});
97-
const result = await topicalAlignmentCheck(ctx, 'Hi', config);
237+
const config: TopicalAlignmentConfig = {
238+
model: 'gpt-4',
239+
confidence_threshold: 0.7,
240+
system_prompt_details: 'Stay on topic about finance.',
241+
};
98242

99-
consoleSpy.mockRestore();
243+
const result = await topicalAlignment(ctx, 'Maybe off topic', config);
100244

101245
expect(result.tripwireTriggered).toBe(false);
102-
expect(result.info?.error).toBeDefined();
246+
expect(result.info?.flagged).toBe(true);
247+
expect(result.info?.confidence).toBe(0.5);
103248
});
104249

105-
it('handles unexpected errors gracefully', async () => {
250+
it('handles execution failures gracefully', async () => {
251+
vi.doUnmock('../../../checks/llm-base');
252+
vi.doUnmock('../../../checks/topical-alignment');
253+
106254
const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {});
107-
const { topicalAlignmentCheck } = await import('../../../checks/topical-alignment');
255+
const { topicalAlignment } = await import('../../../checks/topical-alignment');
108256
const ctx = {
109257
guardrailLlm: {
110258
chat: {
111259
completions: {
112-
create: vi.fn().mockRejectedValue(new Error('timeout')),
260+
create: vi.fn().mockRejectedValue(new Error('API timeout')),
113261
},
114262
},
263+
baseURL: 'https://api.openai.com/v1',
115264
},
265+
} as GuardrailLLMContext;
266+
267+
const config: TopicalAlignmentConfig = {
268+
model: 'gpt-4',
269+
confidence_threshold: 0.7,
270+
system_prompt_details: 'Stay on topic about finance.',
116271
};
117272

118-
interface MockContext {
119-
guardrailLlm: {
120-
chat: {
121-
completions: {
122-
create: ReturnType<typeof vi.fn>;
123-
};
124-
};
125-
};
126-
}
127-
128-
const result = await topicalAlignmentCheck(ctx as MockContext, 'Test', config);
273+
const result = await topicalAlignment(ctx, 'Test text', config);
129274

130275
expect(result.tripwireTriggered).toBe(false);
131-
expect(result.info?.error).toContain('timeout');
276+
expect(result.executionFailed).toBe(true);
132277
consoleSpy.mockRestore();
133278
});
134279
});

0 commit comments

Comments
 (0)