Skip to content

Commit 612a5c4

Browse files
committed
Add doc strings. Extract shared logic
1 parent 154d545 commit 612a5c4

File tree

2 files changed

+145
-97
lines changed

2 files changed

+145
-97
lines changed

src/__tests__/unit/agents.test.ts

Lines changed: 81 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -358,98 +358,98 @@ describe('GuardrailAgent', () => {
358358

359359
// Test the guardrail function
360360
const guardrailFunction = agent.inputGuardrails[0];
361-
const result = await guardrailFunction.execute('test input');
361+
const result = await guardrailFunction.execute('test input');
362362

363-
expect(result).toHaveProperty('outputInfo');
364-
expect(result).toHaveProperty('tripwireTriggered');
365-
expect(typeof result.tripwireTriggered).toBe('boolean');
366-
});
363+
expect(result).toHaveProperty('outputInfo');
364+
expect(result).toHaveProperty('tripwireTriggered');
365+
expect(typeof result.tripwireTriggered).toBe('boolean');
366+
});
367367

368-
it('passes the latest user message text to guardrails for conversation inputs', async () => {
369-
process.env.OPENAI_API_KEY = 'test';
370-
const config = {
371-
version: 1,
372-
input: {
368+
it('passes the latest user message text to guardrails for conversation inputs', async () => {
369+
process.env.OPENAI_API_KEY = 'test';
370+
const config = {
373371
version: 1,
374-
guardrails: [{ name: 'Moderation', config: {} }],
375-
},
376-
};
372+
input: {
373+
version: 1,
374+
guardrails: [{ name: 'Moderation', config: {} }],
375+
},
376+
};
377377

378-
const { instantiateGuardrails } = await import('../../runtime');
379-
const runSpy = vi.fn().mockResolvedValue({
380-
tripwireTriggered: false,
381-
info: { guardrail_name: 'Moderation' },
382-
});
378+
const { instantiateGuardrails } = await import('../../runtime');
379+
const runSpy = vi.fn().mockResolvedValue({
380+
tripwireTriggered: false,
381+
info: { guardrail_name: 'Moderation' },
382+
});
383383

384-
vi.mocked(instantiateGuardrails).mockImplementationOnce(() =>
385-
Promise.resolve([
384+
vi.mocked(instantiateGuardrails).mockImplementationOnce(() =>
385+
Promise.resolve([
386+
{
387+
definition: {
388+
name: 'Moderation',
389+
description: 'Moderation guardrail',
390+
mediaType: 'text/plain',
391+
configSchema: z.object({}),
392+
checkFn: vi.fn(),
393+
metadata: {},
394+
ctxRequirements: z.object({}),
395+
schema: () => ({}),
396+
instantiate: vi.fn(),
397+
},
398+
config: {},
399+
run: runSpy,
400+
} as unknown as Parameters<typeof instantiateGuardrails>[0] extends Promise<infer T>
401+
? T extends readonly (infer U)[]
402+
? U
403+
: never
404+
: never,
405+
])
406+
);
407+
408+
const agent = (await GuardrailAgent.create(
409+
config,
410+
'Conversation Agent',
411+
'Handle multi-turn conversations'
412+
)) as MockAgent;
413+
414+
const guardrail = agent.inputGuardrails[0] as unknown as {
415+
execute: (args: { input: unknown; context?: unknown }) => Promise<{
416+
outputInfo: Record<string, unknown>;
417+
tripwireTriggered: boolean;
418+
}>;
419+
};
420+
421+
const conversation = [
422+
{ role: 'system', content: 'You are helpful.' },
423+
{ role: 'user', content: [{ type: 'input_text', text: 'First question?' }] },
424+
{ role: 'assistant', content: [{ type: 'output_text', text: 'An answer.' }] },
386425
{
387-
definition: {
388-
name: 'Moderation',
389-
description: 'Moderation guardrail',
390-
mediaType: 'text/plain',
391-
configSchema: z.object({}),
392-
checkFn: vi.fn(),
393-
metadata: {},
394-
ctxRequirements: z.object({}),
395-
schema: () => ({}),
396-
instantiate: vi.fn(),
397-
},
398-
config: {},
399-
run: runSpy,
400-
} as unknown as Parameters<typeof instantiateGuardrails>[0] extends Promise<infer T>
401-
? T extends readonly (infer U)[]
402-
? U
403-
: never
404-
: never,
405-
])
406-
);
407-
408-
const agent = (await GuardrailAgent.create(
409-
config,
410-
'Conversation Agent',
411-
'Handle multi-turn conversations'
412-
)) as MockAgent;
413-
414-
const guardrail = agent.inputGuardrails[0] as unknown as {
415-
execute: (args: { input: unknown; context?: unknown }) => Promise<{
416-
outputInfo: Record<string, unknown>;
417-
tripwireTriggered: boolean;
418-
}>;
419-
};
420-
421-
const conversation = [
422-
{ role: 'system', content: 'You are helpful.' },
423-
{ role: 'user', content: [{ type: 'input_text', text: 'First question?' }] },
424-
{ role: 'assistant', content: [{ type: 'output_text', text: 'An answer.' }] },
425-
{
426-
role: 'user',
427-
content: [
428-
{ type: 'input_text', text: 'Latest user message' },
429-
{ type: 'input_text', text: 'with additional context.' },
430-
],
431-
},
432-
];
426+
role: 'user',
427+
content: [
428+
{ type: 'input_text', text: 'Latest user message' },
429+
{ type: 'input_text', text: 'with additional context.' },
430+
],
431+
},
432+
];
433433

434-
const result = await guardrail.execute({ input: conversation, context: {} });
434+
const result = await guardrail.execute({ input: conversation, context: {} });
435435

436-
expect(runSpy).toHaveBeenCalledTimes(1);
437-
const [ctxArgRaw, dataArg] = runSpy.mock.calls[0] as [unknown, string];
438-
const ctxArg = ctxArgRaw as { getConversationHistory?: () => unknown[] };
439-
expect(dataArg).toBe('Latest user message with additional context.');
440-
expect(typeof ctxArg.getConversationHistory).toBe('function');
436+
expect(runSpy).toHaveBeenCalledTimes(1);
437+
const [ctxArgRaw, dataArg] = runSpy.mock.calls[0] as [unknown, string];
438+
const ctxArg = ctxArgRaw as { getConversationHistory?: () => unknown[] };
439+
expect(dataArg).toBe('Latest user message with additional context.');
440+
expect(typeof ctxArg.getConversationHistory).toBe('function');
441441

442-
const history = ctxArg.getConversationHistory?.() as Array<{ content?: unknown }> | undefined;
443-
expect(Array.isArray(history)).toBe(true);
444-
expect(history && history[history.length - 1]?.content).toBe(
445-
'Latest user message with additional context.'
446-
);
442+
const history = ctxArg.getConversationHistory?.() as Array<{ content?: unknown }> | undefined;
443+
expect(Array.isArray(history)).toBe(true);
444+
expect(history && history[history.length - 1]?.content).toBe(
445+
'Latest user message with additional context.'
446+
);
447447

448-
expect(result.tripwireTriggered).toBe(false);
449-
expect(result.outputInfo.input).toBe('Latest user message with additional context.');
450-
});
448+
expect(result.tripwireTriggered).toBe(false);
449+
expect(result.outputInfo.input).toBe('Latest user message with additional context.');
450+
});
451451

452-
it('should handle guardrail execution errors based on raiseGuardrailErrors setting', async () => {
452+
it('should handle guardrail execution errors based on raiseGuardrailErrors setting', async () => {
453453
process.env.OPENAI_API_KEY = 'test';
454454
const config = {
455455
version: 1,

src/agents.ts

Lines changed: 64 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -249,29 +249,50 @@ function ensureGuardrailContext(
249249
} as GuardrailLLMContext;
250250
}
251251

252-
function extractTextFromContentParts(content: unknown): string {
253-
if (typeof content === 'string') {
254-
return content.trim();
252+
const TEXTUAL_CONTENT_TYPES = new Set(['input_text', 'text', 'output_text', 'summary_text']);
253+
const MAX_CONTENT_EXTRACTION_DEPTH = 10;
254+
255+
/**
256+
* Extract text from any nested content value with optional type filtering.
257+
*
258+
* @param value Arbitrary content value (string, array, or object) to inspect.
259+
* @param depth Current recursion depth, used to guard against circular structures.
260+
* @param filterByType When true, only content parts with recognised text types are returned.
261+
* @returns The extracted text, or an empty string when no text is found.
262+
*/
263+
function extractTextFromValue(value: unknown, depth: number, filterByType: boolean): string {
264+
if (depth > MAX_CONTENT_EXTRACTION_DEPTH) {
265+
return '';
266+
}
267+
268+
if (typeof value === 'string') {
269+
return value.trim();
255270
}
256271

257-
if (Array.isArray(content)) {
272+
if (Array.isArray(value)) {
258273
const parts: string[] = [];
259-
for (const item of content) {
260-
const text = extractTextFromMessageEntry(item);
274+
for (const item of value) {
275+
const text = extractTextFromValue(item, depth + 1, filterByType);
261276
if (text) {
262277
parts.push(text);
263278
}
264279
}
265280
return parts.join(' ').trim();
266281
}
267282

268-
if (content && typeof content === 'object') {
269-
const record = content as Record<string, unknown>;
283+
if (value && typeof value === 'object') {
284+
const record = value as Record<string, unknown>;
285+
const typeValue = typeof record.type === 'string' ? record.type : null;
286+
const isRecognisedTextType = typeValue ? TEXTUAL_CONTENT_TYPES.has(typeValue) : false;
287+
270288
if (typeof record.text === 'string') {
271-
return record.text.trim();
289+
if (!filterByType || isRecognisedTextType) {
290+
return record.text.trim();
291+
}
272292
}
293+
273294
if (record.content !== undefined) {
274-
const nested = extractTextFromContentParts(record.content);
295+
const nested = extractTextFromValue(record.content, depth + 1, filterByType);
275296
if (nested) {
276297
return nested;
277298
}
@@ -281,7 +302,27 @@ function extractTextFromContentParts(content: unknown): string {
281302
return '';
282303
}
283304

284-
function extractTextFromMessageEntry(entry: unknown): string {
305+
/**
306+
* Extract text from structured content parts (e.g., the `content` field on a message).
307+
*
308+
* Only recognised textual content-part types are considered to match the behaviour of
309+
* `ContentUtils.filterToTextOnly`, ensuring non-text modalities are ignored.
310+
*/
311+
function extractTextFromContentParts(content: unknown, depth = 0): string {
312+
return extractTextFromValue(content, depth, true);
313+
}
314+
315+
/**
316+
* Extract text from a single message entry.
317+
*
318+
* Handles strings, arrays of content parts, or message-like objects that contain a
319+
* `content` collection or a plain `text` field.
320+
*/
321+
function extractTextFromMessageEntry(entry: unknown, depth = 0): string {
322+
if (depth > MAX_CONTENT_EXTRACTION_DEPTH) {
323+
return '';
324+
}
325+
285326
if (entry == null) {
286327
return '';
287328
}
@@ -291,13 +332,14 @@ function extractTextFromMessageEntry(entry: unknown): string {
291332
}
292333

293334
if (Array.isArray(entry)) {
294-
return extractTextFromContentParts(entry);
335+
return extractTextFromContentParts(entry, depth + 1);
295336
}
296337

297338
if (typeof entry === 'object') {
298339
const record = entry as Record<string, unknown>;
340+
299341
if (record.content !== undefined) {
300-
const contentText = extractTextFromContentParts(record.content);
342+
const contentText = extractTextFromContentParts(record.content, depth + 1);
301343
if (contentText) {
302344
return contentText;
303345
}
@@ -308,9 +350,15 @@ function extractTextFromMessageEntry(entry: unknown): string {
308350
}
309351
}
310352

311-
return '';
353+
return extractTextFromValue(entry, depth + 1, false);
312354
}
313355

356+
/**
357+
* Extract the latest user-authored text from raw agent input.
358+
*
359+
* Accepts strings, message objects, or arrays of mixed items. Arrays are scanned
360+
* from newest to oldest, returning the first user-role message with textual content.
361+
*/
314362
function extractTextFromAgentInput(input: unknown): string {
315363
if (typeof input === 'string') {
316364
return input.trim();
@@ -327,8 +375,8 @@ function extractTextFromAgentInput(input: unknown): string {
327375
return text;
328376
}
329377
}
330-
} else {
331-
const text = extractTextFromMessageEntry(candidate);
378+
} else if (typeof candidate === 'string') {
379+
const text = candidate.trim();
332380
if (text) {
333381
return text;
334382
}

0 commit comments

Comments
 (0)