-
Notifications
You must be signed in to change notification settings - Fork 9
Fix: Correctly use context model for moderation #26
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
db06b9f
1a1c15c
b530cac
de60e8e
1bab19a
bf8a028
fa86851
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -66,6 +66,85 @@ describe('moderation guardrail', () => { | |
| expect(result.tripwireTriggered).toBe(false); | ||
| expect(result.info?.error).toBe('Moderation API call failed'); | ||
| }); | ||
|
|
||
| it('uses context client when available', async () => { | ||
| // Track whether context client was used | ||
| let contextClientUsed = false; | ||
| const contextCreateMock = vi.fn().mockImplementation(async () => { | ||
| contextClientUsed = true; | ||
| return { | ||
| results: [ | ||
| { | ||
| categories: { | ||
| [Category.HATE]: false, | ||
| [Category.VIOLENCE]: false, | ||
| }, | ||
| }, | ||
| ], | ||
| }; | ||
| }); | ||
|
|
||
| // Create a context with a guardrailLlm client | ||
| // We need to import OpenAI to create a proper instance | ||
| const OpenAI = (await import('openai')).default; | ||
| const contextClient = new OpenAI({ apiKey: 'test-context-key' }); | ||
| contextClient.moderations = { | ||
| create: contextCreateMock, | ||
| } as unknown as typeof contextClient.moderations; | ||
|
|
||
| const ctx = { guardrailLlm: contextClient }; | ||
| const cfg = ModerationConfig.parse({ categories: [Category.HATE] }); | ||
| const result = await moderationCheck(ctx, 'test text', cfg); | ||
|
|
||
| // Verify the context client was used | ||
| expect(contextClientUsed).toBe(true); | ||
| expect(contextCreateMock).toHaveBeenCalledWith({ | ||
| model: 'omni-moderation-latest', | ||
| input: 'test text', | ||
| }); | ||
| expect(result.tripwireTriggered).toBe(false); | ||
| }); | ||
|
|
||
| it('falls back to default client for third-party providers', async () => { | ||
| // Track whether fallback client was used | ||
| let fallbackUsed = false; | ||
|
|
||
| // The default mock from vi.mock will be used for the fallback | ||
| createMock.mockImplementation(async () => { | ||
| fallbackUsed = true; | ||
|
Comment on lines
+115
to
+116
|
||
| return { | ||
| results: [ | ||
| { | ||
| categories: { | ||
| [Category.HATE]: false, | ||
| }, | ||
| }, | ||
| ], | ||
| }; | ||
| }); | ||
|
|
||
| // Create a context client that simulates a third-party provider | ||
| // When moderation is called, it should raise a 404 error | ||
| const contextCreateMock = vi.fn().mockRejectedValue({ | ||
| status: 404, | ||
| message: '404 page not found', | ||
| }); | ||
|
|
||
| const OpenAI = (await import('openai')).default; | ||
| const thirdPartyClient = new OpenAI({ apiKey: 'third-party-key', baseURL: 'https://localhost:8080/v1' }); | ||
| thirdPartyClient.moderations = { | ||
| create: contextCreateMock, | ||
| } as unknown as typeof thirdPartyClient.moderations; | ||
|
|
||
| const ctx = { guardrailLlm: thirdPartyClient }; | ||
| const cfg = ModerationConfig.parse({ categories: [Category.HATE] }); | ||
| const result = await moderationCheck(ctx, 'test text', cfg); | ||
|
|
||
| // Verify the fallback client was used (not the third-party one) | ||
| expect(contextCreateMock).toHaveBeenCalled(); | ||
| expect(fallbackUsed).toBe(true); | ||
| expect(result.tripwireTriggered).toBe(false); | ||
| }); | ||
| }); | ||
|
|
||
| describe('secret key guardrail', () => { | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -78,6 +78,33 @@ export const ModerationContext = z.object({ | |||||||||
|
|
||||||||||
| export type ModerationContext = z.infer<typeof ModerationContext>; | ||||||||||
|
|
||||||||||
| /** | ||||||||||
| * Check if an error is a 404 Not Found error from the OpenAI API. | ||||||||||
| * | ||||||||||
| * @param error The error to check | ||||||||||
| * @returns True if the error is a 404 error | ||||||||||
| */ | ||||||||||
| function isNotFoundError(error: unknown): boolean { | ||||||||||
| return !!(error && typeof error === 'object' && 'status' in error && error.status === 404); | ||||||||||
| } | ||||||||||
|
|
||||||||||
| /** | ||||||||||
| * Call the OpenAI moderation API. | ||||||||||
| * | ||||||||||
| * @param client The OpenAI client to use | ||||||||||
| * @param data The text to analyze | ||||||||||
| * @returns The moderation API response | ||||||||||
| */ | ||||||||||
| function callModerationAPI( | ||||||||||
| client: OpenAI, | ||||||||||
| data: string | ||||||||||
| ): ReturnType<OpenAI['moderations']['create']> { | ||||||||||
| return client.moderations.create({ | ||||||||||
| model: 'omni-moderation-latest', | ||||||||||
| input: data, | ||||||||||
| }); | ||||||||||
| } | ||||||||||
|
|
||||||||||
| /** | ||||||||||
| * Guardrail check_fn to flag disallowed content categories using OpenAI moderation API. | ||||||||||
| * | ||||||||||
|
|
@@ -102,39 +129,54 @@ export const moderationCheck: CheckFn<ModerationContext, string, ModerationConfi | |||||||||
| const configObj = actualConfig as Record<string, unknown>; | ||||||||||
| const categories = (configObj.categories as string[]) || Object.values(Category); | ||||||||||
|
|
||||||||||
| // Reuse provided client only if it targets the official OpenAI API. | ||||||||||
| const reuseClientIfOpenAI = (context: unknown): OpenAI | null => { | ||||||||||
| try { | ||||||||||
| const contextObj = context as Record<string, unknown>; | ||||||||||
| const candidate = contextObj?.guardrailLlm; | ||||||||||
| if (!candidate || typeof candidate !== 'object') return null; | ||||||||||
| if (!(candidate instanceof OpenAI)) return null; | ||||||||||
|
|
||||||||||
| const candidateObj = candidate as unknown as Record<string, unknown>; | ||||||||||
| const baseURL: string | undefined = | ||||||||||
| (candidateObj.baseURL as string) ?? | ||||||||||
| ((candidateObj._client as Record<string, unknown>)?.baseURL as string) ?? | ||||||||||
| (candidateObj._baseURL as string); | ||||||||||
|
|
||||||||||
| if ( | ||||||||||
| baseURL === undefined || | ||||||||||
| (typeof baseURL === 'string' && baseURL.includes('api.openai.com')) | ||||||||||
| ) { | ||||||||||
| return candidate as OpenAI; | ||||||||||
| } | ||||||||||
| return null; | ||||||||||
| } catch { | ||||||||||
| return null; | ||||||||||
| // Get client from context if available | ||||||||||
| let client: OpenAI | null = null; | ||||||||||
| if (ctx) { | ||||||||||
| const contextObj = ctx as Record<string, unknown>; | ||||||||||
| const candidate = contextObj.guardrailLlm; | ||||||||||
| if (candidate && candidate instanceof OpenAI) { | ||||||||||
| client = candidate; | ||||||||||
| } | ||||||||||
| }; | ||||||||||
|
|
||||||||||
| const client = reuseClientIfOpenAI(ctx) ?? new OpenAI(); | ||||||||||
| } | ||||||||||
|
Comment on lines
142
to
150
|
||||||||||
|
|
||||||||||
| try { | ||||||||||
| const resp = await client.moderations.create({ | ||||||||||
| model: 'omni-moderation-latest', | ||||||||||
| input: data, | ||||||||||
| }); | ||||||||||
| // Try the context client first, fall back if moderation endpoint doesn't exist | ||||||||||
| let resp: Awaited<ReturnType<typeof callModerationAPI>>; | ||||||||||
| if (client !== null) { | ||||||||||
| try { | ||||||||||
| resp = await callModerationAPI(client, data); | ||||||||||
| } catch (error) { | ||||||||||
| // Moderation endpoint doesn't exist on this provider (e.g., third-party) | ||||||||||
| // Fall back to the OpenAI client | ||||||||||
| if (isNotFoundError(error)) { | ||||||||||
| try { | ||||||||||
| resp = await callModerationAPI(new OpenAI(), data); | ||||||||||
|
Comment on lines
161
to
163
|
||||||||||
| } catch (fallbackError) { | ||||||||||
| // If fallback fails, provide a helpful error message | ||||||||||
| const errorMessage = fallbackError instanceof Error | ||||||||||
| ? fallbackError.message | ||||||||||
| : String(fallbackError); | ||||||||||
|
Comment on lines
166
to
168
|
||||||||||
|
|
||||||||||
| // Check if it's an API key error | ||||||||||
| if (errorMessage.includes('api_key') || errorMessage.includes('OPENAI_API_KEY')) { | ||||||||||
|
Comment on lines
170
to
171
|
||||||||||
| // Check if it's an API key error | |
| if (errorMessage.includes('api_key') || errorMessage.includes('OPENAI_API_KEY')) { | |
| // Check if it's an API key error (HTTP 401 Unauthorized) | |
| if (typeof fallbackError === 'object' && fallbackError !== null && 'status' in fallbackError && fallbackError.status === 401) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm okay with this. It is still throwing an appropriate error regardless; this check just adds more useful information for the user.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
[nitpick] The `contextClientUsed` flag is unnecessary. You can check if the mock was called using `expect(contextCreateMock).toHaveBeenCalled()` instead of tracking this manually.