Skip to content

Commit 154d545

Browse files
committed
Agent parse conversation history
1 parent 54e5806 commit 154d545

File tree

2 files changed

+209
-21
lines changed

2 files changed

+209
-21
lines changed

src/__tests__/unit/agents.test.ts

Lines changed: 89 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -358,14 +358,98 @@ describe('GuardrailAgent', () => {
358358

359359
// Test the guardrail function
360360
const guardrailFunction = agent.inputGuardrails[0];
361-
const result = await guardrailFunction.execute('test input');
361+
const result = await guardrailFunction.execute('test input');
362362

363-
expect(result).toHaveProperty('outputInfo');
364-
expect(result).toHaveProperty('tripwireTriggered');
365-
expect(typeof result.tripwireTriggered).toBe('boolean');
363+
expect(result).toHaveProperty('outputInfo');
364+
expect(result).toHaveProperty('tripwireTriggered');
365+
expect(typeof result.tripwireTriggered).toBe('boolean');
366+
});
367+
368+
it('passes the latest user message text to guardrails for conversation inputs', async () => {
369+
process.env.OPENAI_API_KEY = 'test';
370+
const config = {
371+
version: 1,
372+
input: {
373+
version: 1,
374+
guardrails: [{ name: 'Moderation', config: {} }],
375+
},
376+
};
377+
378+
const { instantiateGuardrails } = await import('../../runtime');
379+
const runSpy = vi.fn().mockResolvedValue({
380+
tripwireTriggered: false,
381+
info: { guardrail_name: 'Moderation' },
366382
});
367383

368-
it('should handle guardrail execution errors based on raiseGuardrailErrors setting', async () => {
384+
vi.mocked(instantiateGuardrails).mockImplementationOnce(() =>
385+
Promise.resolve([
386+
{
387+
definition: {
388+
name: 'Moderation',
389+
description: 'Moderation guardrail',
390+
mediaType: 'text/plain',
391+
configSchema: z.object({}),
392+
checkFn: vi.fn(),
393+
metadata: {},
394+
ctxRequirements: z.object({}),
395+
schema: () => ({}),
396+
instantiate: vi.fn(),
397+
},
398+
config: {},
399+
run: runSpy,
400+
} as unknown as Parameters<typeof instantiateGuardrails>[0] extends Promise<infer T>
401+
? T extends readonly (infer U)[]
402+
? U
403+
: never
404+
: never,
405+
])
406+
);
407+
408+
const agent = (await GuardrailAgent.create(
409+
config,
410+
'Conversation Agent',
411+
'Handle multi-turn conversations'
412+
)) as MockAgent;
413+
414+
const guardrail = agent.inputGuardrails[0] as unknown as {
415+
execute: (args: { input: unknown; context?: unknown }) => Promise<{
416+
outputInfo: Record<string, unknown>;
417+
tripwireTriggered: boolean;
418+
}>;
419+
};
420+
421+
const conversation = [
422+
{ role: 'system', content: 'You are helpful.' },
423+
{ role: 'user', content: [{ type: 'input_text', text: 'First question?' }] },
424+
{ role: 'assistant', content: [{ type: 'output_text', text: 'An answer.' }] },
425+
{
426+
role: 'user',
427+
content: [
428+
{ type: 'input_text', text: 'Latest user message' },
429+
{ type: 'input_text', text: 'with additional context.' },
430+
],
431+
},
432+
];
433+
434+
const result = await guardrail.execute({ input: conversation, context: {} });
435+
436+
expect(runSpy).toHaveBeenCalledTimes(1);
437+
const [ctxArgRaw, dataArg] = runSpy.mock.calls[0] as [unknown, string];
438+
const ctxArg = ctxArgRaw as { getConversationHistory?: () => unknown[] };
439+
expect(dataArg).toBe('Latest user message with additional context.');
440+
expect(typeof ctxArg.getConversationHistory).toBe('function');
441+
442+
const history = ctxArg.getConversationHistory?.() as Array<{ content?: unknown }> | undefined;
443+
expect(Array.isArray(history)).toBe(true);
444+
expect(history && history[history.length - 1]?.content).toBe(
445+
'Latest user message with additional context.'
446+
);
447+
448+
expect(result.tripwireTriggered).toBe(false);
449+
expect(result.outputInfo.input).toBe('Latest user message with additional context.');
450+
});
451+
452+
it('should handle guardrail execution errors based on raiseGuardrailErrors setting', async () => {
369453
process.env.OPENAI_API_KEY = 'test';
370454
const config = {
371455
version: 1,

src/agents.ts

Lines changed: 120 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,7 @@ import type {
1313
InputGuardrailFunctionArgs,
1414
OutputGuardrailFunctionArgs,
1515
} from '@openai/agents-core';
16-
import { GuardrailLLMContext, GuardrailResult, TextOnlyContent, ContentPart } from './types';
17-
import { ContentUtils } from './utils/content';
16+
import { GuardrailLLMContext, GuardrailResult, TextOnlyContent } from './types';
1817
import {
1918
loadPipelineBundles,
2019
instantiateGuardrails,
@@ -250,6 +249,122 @@ function ensureGuardrailContext(
250249
} as GuardrailLLMContext;
251250
}
252251

252+
function extractTextFromContentParts(content: unknown): string {
253+
if (typeof content === 'string') {
254+
return content.trim();
255+
}
256+
257+
if (Array.isArray(content)) {
258+
const parts: string[] = [];
259+
for (const item of content) {
260+
const text = extractTextFromMessageEntry(item);
261+
if (text) {
262+
parts.push(text);
263+
}
264+
}
265+
return parts.join(' ').trim();
266+
}
267+
268+
if (content && typeof content === 'object') {
269+
const record = content as Record<string, unknown>;
270+
if (typeof record.text === 'string') {
271+
return record.text.trim();
272+
}
273+
if (record.content !== undefined) {
274+
const nested = extractTextFromContentParts(record.content);
275+
if (nested) {
276+
return nested;
277+
}
278+
}
279+
}
280+
281+
return '';
282+
}
283+
284+
function extractTextFromMessageEntry(entry: unknown): string {
285+
if (entry == null) {
286+
return '';
287+
}
288+
289+
if (typeof entry === 'string') {
290+
return entry.trim();
291+
}
292+
293+
if (Array.isArray(entry)) {
294+
return extractTextFromContentParts(entry);
295+
}
296+
297+
if (typeof entry === 'object') {
298+
const record = entry as Record<string, unknown>;
299+
if (record.content !== undefined) {
300+
const contentText = extractTextFromContentParts(record.content);
301+
if (contentText) {
302+
return contentText;
303+
}
304+
}
305+
306+
if (typeof record.text === 'string') {
307+
return record.text.trim();
308+
}
309+
}
310+
311+
return '';
312+
}
313+
314+
function extractTextFromAgentInput(input: unknown): string {
315+
if (typeof input === 'string') {
316+
return input.trim();
317+
}
318+
319+
if (Array.isArray(input)) {
320+
for (let idx = input.length - 1; idx >= 0; idx -= 1) {
321+
const candidate = input[idx];
322+
if (candidate && typeof candidate === 'object') {
323+
const record = candidate as Record<string, unknown>;
324+
if (record.role === 'user') {
325+
const text = extractTextFromMessageEntry(candidate);
326+
if (text) {
327+
return text;
328+
}
329+
}
330+
} else {
331+
const text = extractTextFromMessageEntry(candidate);
332+
if (text) {
333+
return text;
334+
}
335+
}
336+
}
337+
return '';
338+
}
339+
340+
if (input && typeof input === 'object') {
341+
const record = input as Record<string, unknown>;
342+
if (record.role === 'user') {
343+
const text = extractTextFromMessageEntry(record);
344+
if (text) {
345+
return text;
346+
}
347+
}
348+
349+
if (record.content !== undefined) {
350+
const contentText = extractTextFromContentParts(record.content);
351+
if (contentText) {
352+
return contentText;
353+
}
354+
}
355+
356+
if (typeof record.text === 'string') {
357+
return record.text.trim();
358+
}
359+
}
360+
361+
if (input == null) {
362+
return '';
363+
}
364+
365+
return String(input);
366+
}
367+
253368
function extractLatestUserText(history: NormalizedConversationEntry[]): string {
254369
for (let i = history.length - 1; i >= 0; i -= 1) {
255370
const entry = history[i];
@@ -261,20 +376,9 @@ function extractLatestUserText(history: NormalizedConversationEntry[]): string {
261376
}
262377

263378
function resolveInputText(input: unknown, history: NormalizedConversationEntry[]): string {
264-
if (typeof input === 'string') {
265-
return input;
266-
}
267-
268-
if (input && typeof input === 'object' && 'content' in (input as Record<string, unknown>)) {
269-
const content = (input as { content: string | ContentPart[] }).content;
270-
const message = {
271-
role: 'user',
272-
content,
273-
};
274-
const extracted = ContentUtils.extractTextFromMessage(message);
275-
if (extracted) {
276-
return extracted;
277-
}
379+
const directText = extractTextFromAgentInput(input);
380+
if (directText) {
381+
return directText;
278382
}
279383

280384
return extractLatestUserText(history);

0 commit comments

Comments
 (0)