Skip to content

Commit 209abbb

Browse files
feat(openai-assistant): handled "requires_action" event
1 parent 059e283 commit 209abbb

File tree

4 files changed

+19
-133
lines changed

4 files changed

+19
-133
lines changed

libs/openai-assistant/src/lib/chat/chat.service.spec.ts

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,12 @@ import { AiModule } from './../ai/ai.module';
55
import { ChatModule } from './chat.module';
66
import { ChatService } from './chat.service';
77
import { ChatHelpers } from './chat.helpers';
8-
import { RunService } from '../run';
98
import { ChatCallDto } from './chat.model';
109
import { AssistantStream } from 'openai/lib/AssistantStream';
1110

1211
describe('ChatService', () => {
1312
let chatService: ChatService;
1413
let chatbotHelpers: ChatHelpers;
15-
let runService: RunService;
1614

1715
beforeEach(async () => {
1816
const moduleRef = await Test.createTestingModule({
@@ -21,23 +19,19 @@ describe('ChatService', () => {
2119

2220
chatService = moduleRef.get<ChatService>(ChatService);
2321
chatbotHelpers = moduleRef.get<ChatHelpers>(ChatHelpers);
24-
runService = moduleRef.get<RunService>(RunService);
2522

2623
jest
2724
.spyOn(chatbotHelpers, 'getAnswer')
2825
.mockReturnValue(Promise.resolve('Hello response') as Promise<string>);
2926

30-
jest.spyOn(runService, 'resolve').mockReturnThis();
3127

3228
jest
3329
.spyOn(chatService.threads.messages, 'create')
3430
.mockReturnValue({} as APIPromise<Message>);
3531

3632
jest.spyOn(chatService, 'assistantStream').mockReturnValue({
37-
finalRun(): Promise<Run> {
38-
return Promise.resolve({} as Run);
39-
},
40-
} as AssistantStream);
33+
finalRun: jest.fn(),
34+
} as unknown as Promise<AssistantStream>);
4135
});
4236

4337
it('should be defined', () => {

libs/openai-assistant/src/lib/chat/chat.service.ts

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import {
77
ChatCallResponseDto,
88
} from './chat.model';
99
import { ChatHelpers } from './chat.helpers';
10-
import { MessageCreateParams } from 'openai/resources/beta/threads';
10+
import { MessageCreateParams, Run } from 'openai/resources/beta/threads';
1111
import { AssistantStream } from 'openai/lib/AssistantStream';
1212
import { assistantStreamEventHandler } from '../stream/stream.utils';
1313

@@ -36,27 +36,29 @@ export class ChatService {
3636

3737
await this.threads.messages.create(threadId, message);
3838

39-
const assistantId =
40-
payload?.assistantId || process.env['ASSISTANT_ID'] || '';
41-
const run = this.assistantStream(assistantId, threadId, callbacks);
42-
const finalRun = await run.finalRun();
43-
44-
await this.runService.resolve(finalRun, true, callbacks);
39+
const runner = await this.assistantStream(payload, callbacks);
40+
const finalRun = await runner.finalRun();
4541

4642
return {
4743
content: await this.chatbotHelpers.getAnswer(finalRun),
4844
threadId,
4945
};
5046
}
5147

52-
assistantStream(
53-
assistantId: string,
54-
threadId: string,
48+
async assistantStream(
49+
payload: ChatCallDto,
5550
callbacks?: ChatCallCallbacks,
56-
): AssistantStream {
57-
const runner = this.threads.runs.createAndStream(threadId, {
58-
assistant_id: assistantId,
59-
});
51+
): Promise<AssistantStream> {
52+
const assistant_id =
53+
payload?.assistantId || process.env['ASSISTANT_ID'] || '';
54+
55+
const runner = this.threads.runs
56+
.createAndStream(payload.threadId, { assistant_id })
57+
.on('event', event => {
58+
if (event.event === 'thread.run.requires_action') {
59+
this.runService.submitAction(event.data, callbacks);
60+
}
61+
});
6062

6163
return assistantStreamEventHandler<AssistantStream>(runner, callbacks);
6264
}

libs/openai-assistant/src/lib/run/run.service.spec.ts

Lines changed: 0 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -33,80 +33,6 @@ describe('RunService', () => {
3333
expect(runService).toBeDefined();
3434
});
3535

36-
describe('continueRun', () => {
37-
it('should call threads.runs.retrieve', async () => {
38-
const spyOnRetrieve = jest
39-
.spyOn(aiService.provider.beta.threads.runs, 'retrieve')
40-
.mockReturnThis();
41-
const run = { thread_id: '1', id: '123' } as Run;
42-
43-
await runService.continueRun(run);
44-
45-
expect(spyOnRetrieve).toHaveBeenCalled();
46-
});
47-
48-
it('should wait for timeout', async () => {
49-
const run = { thread_id: '1', id: '123' } as Run;
50-
const spyOnTimeout = jest.spyOn(global, 'setTimeout');
51-
52-
await runService.continueRun(run);
53-
54-
expect(spyOnTimeout).toHaveBeenCalledWith(
55-
expect.any(Function),
56-
runService.timeout,
57-
);
58-
});
59-
});
60-
61-
describe('resolve', () => {
62-
it('should call continueRun', async () => {
63-
const spyOnContinueRun = jest
64-
.spyOn(runService, 'continueRun')
65-
.mockResolvedValue({} as Run);
66-
const run = { status: 'requires_action' } as Run;
67-
68-
await runService.resolve(run, false);
69-
70-
expect(spyOnContinueRun).toHaveBeenCalled();
71-
});
72-
73-
it('should call submitAction', async () => {
74-
const spyOnSubmitAction = jest
75-
.spyOn(runService, 'submitAction')
76-
.mockResolvedValue();
77-
const run = {
78-
status: 'requires_action',
79-
required_action: { type: 'submit_tool_outputs' },
80-
} as Run;
81-
82-
await runService.resolve(run, false);
83-
84-
expect(spyOnSubmitAction).toHaveBeenCalled();
85-
});
86-
87-
it('should call default', async () => {
88-
const spyOnContinueRun = jest
89-
.spyOn(runService, 'continueRun')
90-
.mockResolvedValue({} as Run);
91-
const run = { status: 'unknown' } as unknown as Run;
92-
93-
await runService.resolve(run, false);
94-
95-
expect(spyOnContinueRun).toHaveBeenCalled();
96-
});
97-
98-
it('should not invoke action when status is cancelling', async () => {
99-
const spyOnContinueRun = jest
100-
.spyOn(runService, 'continueRun')
101-
.mockResolvedValue({} as Run);
102-
const run = { status: 'cancelling' } as unknown as Run;
103-
104-
await runService.resolve(run, false);
105-
106-
expect(spyOnContinueRun).not.toHaveBeenCalled();
107-
});
108-
});
109-
11036
describe('submitAction', () => {
11137
it('should call submitToolOutputsStream', async () => {
11238
const spyOnSubmitToolOutputsStream = jest

libs/openai-assistant/src/lib/run/run.service.ts

Lines changed: 1 addition & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,5 @@
11
import { Injectable } from '@nestjs/common';
2-
import {
3-
Run,
4-
RunSubmitToolOutputsParams,
5-
Text,
6-
TextDelta,
7-
} from 'openai/resources/beta/threads';
2+
import { Run, RunSubmitToolOutputsParams } from 'openai/resources/beta/threads';
83
import { AiService } from '../ai';
94
import { AgentService } from '../agent';
105
import { ChatCallCallbacks } from '../chat';
@@ -13,43 +8,12 @@ import { assistantStreamEventHandler } from '../stream/stream.utils';
138
@Injectable()
149
export class RunService {
1510
private readonly threads = this.aiService.provider.beta.threads;
16-
timeout = 2000;
17-
isRunning = true;
1811

1912
constructor(
2013
private readonly aiService: AiService,
2114
private readonly agentsService: AgentService,
2215
) {}
2316

24-
async continueRun(run: Run): Promise<Run> {
25-
await new Promise(resolve => setTimeout(resolve, this.timeout));
26-
return this.threads.runs.retrieve(run.thread_id, run.id);
27-
}
28-
29-
async resolve(
30-
run: Run,
31-
runningStatus: boolean,
32-
callbacks?: ChatCallCallbacks,
33-
): Promise<void> {
34-
while (this.isRunning)
35-
switch (run.status) {
36-
case 'cancelling':
37-
case 'cancelled':
38-
case 'failed':
39-
case 'expired':
40-
case 'completed':
41-
return;
42-
case 'requires_action':
43-
await this.submitAction(run, callbacks);
44-
run = await this.continueRun(run);
45-
this.isRunning = runningStatus;
46-
continue;
47-
default:
48-
run = await this.continueRun(run);
49-
this.isRunning = runningStatus;
50-
}
51-
}
52-
5317
async submitAction(run: Run, callbacks?: ChatCallCallbacks): Promise<void> {
5418
if (run.required_action?.type !== 'submit_tool_outputs') {
5519
return;

0 commit comments

Comments
 (0)