Skip to content

Commit 8d783f8

Browse files
fix(assistant): changed to a basic "run resolver" to avoid issues with many function callings
1 parent 209abbb commit 8d783f8

File tree

9 files changed

+132
-75
lines changed

9 files changed

+132
-75
lines changed

.env.dist

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
OPENAI_API_KEY=
33
# Assistant ID - leave it empty if you don't have an assistant yet
44
ASSISTANT_ID=
5+
ASSISTANT_IS_LOGGER_ENABLED=
56

67
# Agents:
78
# -------------------------------------------------------------------

libs/openai-assistant/package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
{
22
"name": "@boldare/openai-assistant",
33
"description": "NestJS library for building chatbot solutions based on the OpenAI Assistant API",
4-
"version": "1.0.0",
4+
"version": "1.0.2",
55
"private": false,
66
"dependencies": {
77
"tslib": "^2.3.0",

libs/openai-assistant/src/lib/assistant/assistant.service.spec.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ describe('AssistantService', () => {
7474
jest
7575
.spyOn(aiService.provider.beta.assistants, 'update')
7676
.mockRejectedValueOnce('error');
77-
jest.spyOn(assistantService, 'create').mockResolvedValueOnce(undefined);
77+
jest.spyOn(assistantService, 'create').mockResolvedValueOnce({} as Assistant);
7878

7979
await assistantService.init();
8080

@@ -97,7 +97,7 @@ describe('AssistantService', () => {
9797
.spyOn(configService, 'get')
9898
.mockReturnValue({ ...assistantConfigMock, id: '' });
9999

100-
jest.spyOn(assistantService, 'create').mockResolvedValueOnce(undefined);
100+
jest.spyOn(assistantService, 'create').mockResolvedValueOnce({} as Assistant);
101101

102102
await assistantService.init();
103103

libs/openai-assistant/src/lib/assistant/assistant.service.ts

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ export class AssistantService {
3030
};
3131
}
3232

33-
async init(): Promise<void> {
33+
async init(): Promise<Assistant> {
3434
const { id, options } = this.assistantConfig.get();
3535

3636
if (!id) {
@@ -43,16 +43,17 @@ export class AssistantService {
4343
this.getParams(),
4444
options,
4545
);
46+
return this.assistant;
4647
} catch (e) {
47-
await this.create();
48+
return await this.create();
4849
}
4950
}
5051

5152
async update(params: Partial<AssistantCreateParams>): Promise<void> {
5253
this.assistant = await this.assistants.update(this.assistant.id, params);
5354
}
5455

55-
async create(): Promise<void> {
56+
async create(): Promise<Assistant> {
5657
const { options } = this.assistantConfig.get();
5758
const params = this.getParams();
5859
this.assistant = await this.assistants.create(params, options);
@@ -63,6 +64,8 @@ export class AssistantService {
6364

6465
this.logger.log(`Created new assistant (${this.assistant.id})`);
6566
await this.assistantMemoryService.saveAssistantId(this.assistant.id);
67+
68+
return this.assistant;
6669
}
6770

6871
async updateFiles(fileNames?: string[]): Promise<Assistant> {

libs/openai-assistant/src/lib/chat/chat.gateway.ts

Lines changed: 54 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -39,28 +39,54 @@ export class ChatGateway implements OnGatewayConnection {
3939
this.logger = new Logger(ChatGateway.name);
4040
}
4141

42+
log(message: string): void {
43+
try {
44+
const isLoggerEnabled: string = JSON.parse(
45+
(process.env['ASSISTANT_IS_LOGGER_ENABLED'] || 'false').toLowerCase(),
46+
);
47+
48+
if (isLoggerEnabled) {
49+
this.logger.log(message);
50+
}
51+
} catch (error) {
52+
this.logger.error('"ASSISTANT_IS_LOGGER_ENABLED" should be boolean');
53+
}
54+
}
55+
4256
async handleConnection() {
43-
this.logger.log('Client connected');
57+
this.log('Client connected');
4458
}
4559

4660
getCallbacks(socketId: string): ChatCallCallbacks {
4761
return {
48-
[ChatEvents.MessageCreated]: this.emitMessageCreated.bind(this, socketId),
49-
[ChatEvents.MessageDelta]: this.emitMessageDelta.bind(this, socketId),
50-
[ChatEvents.MessageDone]: this.emitMessageDone.bind(this, socketId),
51-
[ChatEvents.TextCreated]: this.emitTextCreated.bind(this, socketId),
52-
[ChatEvents.TextDelta]: this.emitTextDelta.bind(this, socketId),
53-
[ChatEvents.TextDone]: this.emitTextDone.bind(this, socketId),
62+
[ChatEvents.MessageCreated]: eventData =>
63+
this.emitMessageCreated(socketId, eventData),
64+
[ChatEvents.MessageDelta]: eventData =>
65+
this.emitMessageDelta(socketId, eventData),
66+
[ChatEvents.MessageDone]: eventData =>
67+
this.emitMessageDone(socketId, eventData),
68+
[ChatEvents.TextCreated]: eventData =>
69+
this.emitTextCreated(socketId, eventData),
70+
[ChatEvents.TextDelta]: eventData =>
71+
this.emitTextDelta(socketId, eventData),
72+
[ChatEvents.TextDone]: eventData =>
73+
this.emitTextDone(socketId, eventData),
5474
[ChatEvents.ToolCallCreated]: this.emitToolCallCreated.bind(
5575
this,
5676
socketId,
5777
),
58-
[ChatEvents.ToolCallDelta]: this.emitToolCallDelta.bind(this, socketId),
59-
[ChatEvents.ToolCallDone]: this.emitToolCallDone.bind(this, socketId),
60-
[ChatEvents.ImageFileDone]: this.emitImageFileDone.bind(this, socketId),
61-
[ChatEvents.RunStepCreated]: this.emitRunStepCreated.bind(this, socketId),
62-
[ChatEvents.RunStepDelta]: this.emitRunStepDelta.bind(this, socketId),
63-
[ChatEvents.RunStepDone]: this.emitRunStepDone.bind(this, socketId),
78+
[ChatEvents.ToolCallDelta]: eventData =>
79+
this.emitToolCallDelta(socketId, eventData),
80+
[ChatEvents.ToolCallDone]: eventData =>
81+
this.emitToolCallDone(socketId, eventData),
82+
[ChatEvents.ImageFileDone]: eventData =>
83+
this.emitImageFileDone(socketId, eventData),
84+
[ChatEvents.RunStepCreated]: eventData =>
85+
this.emitRunStepCreated(socketId, eventData),
86+
[ChatEvents.RunStepDelta]: eventData =>
87+
this.emitRunStepDelta(socketId, eventData),
88+
[ChatEvents.RunStepDone]: eventData =>
89+
this.emitRunStepDone(socketId, eventData),
6490
};
6591
}
6692

@@ -69,15 +95,15 @@ export class ChatGateway implements OnGatewayConnection {
6995
@MessageBody() request: ChatCallDto,
7096
@ConnectedSocket() socket: Socket,
7197
) {
72-
this.logger.log(
98+
this.log(
7399
`Socket "${ChatEvents.CallStart}" | threadId ${request.threadId} | files: ${request?.file_ids?.join(', ')} | content: ${request.content}`,
74100
);
75101

76102
const callbacks: ChatCallCallbacks = this.getCallbacks(socket.id);
77103
const message = await this.chatsService.call(request, callbacks);
78104

79105
this.server?.to(socket.id).emit(ChatEvents.CallDone, message);
80-
this.logger.log(
106+
this.log(
81107
`Socket "${ChatEvents.CallDone}" | threadId ${message.threadId} | content: ${message.content}`,
82108
);
83109
}
@@ -87,7 +113,7 @@ export class ChatGateway implements OnGatewayConnection {
87113
@MessageBody() data: MessageCreatedPayload,
88114
) {
89115
this.server.to(socketId).emit(ChatEvents.MessageCreated, data);
90-
this.logger.log(
116+
this.log(
91117
`Socket "${ChatEvents.MessageCreated}" | threadId: ${data.message.thread_id}`,
92118
);
93119
}
@@ -97,7 +123,7 @@ export class ChatGateway implements OnGatewayConnection {
97123
@MessageBody() data: MessageDeltaPayload,
98124
) {
99125
this.server.to(socketId).emit(ChatEvents.MessageDelta, data);
100-
this.logger.log(
126+
this.log(
101127
`Socket "${ChatEvents.MessageDelta}" | threadId: ${data.message.thread_id}`,
102128
);
103129
}
@@ -107,7 +133,7 @@ export class ChatGateway implements OnGatewayConnection {
107133
@MessageBody() data: MessageDonePayload,
108134
) {
109135
this.server.to(socketId).emit(ChatEvents.MessageDone, data);
110-
this.logger.log(
136+
this.log(
111137
`Socket "${ChatEvents.MessageDone}" | threadId: ${data.message.thread_id}`,
112138
);
113139
}
@@ -117,19 +143,17 @@ export class ChatGateway implements OnGatewayConnection {
117143
@MessageBody() data: TextCreatedPayload,
118144
) {
119145
this.server.to(socketId).emit(ChatEvents.TextCreated, data);
120-
this.logger.log(`Socket "${ChatEvents.TextCreated}" | ${data.text.value}`);
146+
this.log(`Socket "${ChatEvents.TextCreated}" | ${data.text.value}`);
121147
}
122148

123149
async emitTextDelta(socketId: string, @MessageBody() data: TextDeltaPayload) {
124150
this.server.to(socketId).emit(ChatEvents.TextDelta, data);
125-
this.logger.log(
126-
`Socket "${ChatEvents.TextDelta}" | ${data.textDelta.value}`,
127-
);
151+
this.log(`Socket "${ChatEvents.TextDelta}" | ${data.textDelta.value}`);
128152
}
129153

130154
async emitTextDone(socketId: string, @MessageBody() data: TextDonePayload) {
131155
this.server.to(socketId).emit(ChatEvents.TextDone, data);
132-
this.logger.log(
156+
this.log(
133157
`Socket "${ChatEvents.TextDone}" | threadId: ${data.message?.thread_id} | ${data.text.value}`,
134158
);
135159
}
@@ -139,9 +163,7 @@ export class ChatGateway implements OnGatewayConnection {
139163
@MessageBody() data: ToolCallCreatedPayload,
140164
) {
141165
this.server.to(socketId).emit(ChatEvents.ToolCallCreated, data);
142-
this.logger.log(
143-
`Socket "${ChatEvents.ToolCallCreated}": ${data.toolCall.id}`,
144-
);
166+
this.log(`Socket "${ChatEvents.ToolCallCreated}": ${data.toolCall.id}`);
145167
}
146168

147169
codeInterpreterHandler(
@@ -185,9 +207,7 @@ export class ChatGateway implements OnGatewayConnection {
185207
socketId: string,
186208
@MessageBody() data: ToolCallDeltaPayload,
187209
) {
188-
this.logger.log(
189-
`Socket "${ChatEvents.ToolCallDelta}": ${data.toolCall.id}`,
190-
);
210+
this.log(`Socket "${ChatEvents.ToolCallDelta}": ${data.toolCall.id}`);
191211

192212
switch (data.toolCallDelta.type) {
193213
case 'code_interpreter':
@@ -211,46 +231,38 @@ export class ChatGateway implements OnGatewayConnection {
211231
@MessageBody() data: ToolCallDonePayload,
212232
) {
213233
this.server.to(socketId).emit(ChatEvents.ToolCallDone, data);
214-
this.logger.log(`Socket "${ChatEvents.ToolCallDone}": ${data.toolCall.id}`);
234+
this.log(`Socket "${ChatEvents.ToolCallDone}": ${data.toolCall.id}`);
215235
}
216236

217237
async emitImageFileDone(
218238
socketId: string,
219239
@MessageBody() data: ImageFileDonePayload,
220240
) {
221241
this.server.to(socketId).emit(ChatEvents.ImageFileDone, data);
222-
this.logger.log(
223-
`Socket "${ChatEvents.ImageFileDone}": ${data.content.file_id}`,
224-
);
242+
this.log(`Socket "${ChatEvents.ImageFileDone}": ${data.content.file_id}`);
225243
}
226244

227245
async emitRunStepCreated(
228246
socketId: string,
229247
@MessageBody() data: RunStepCreatedPayload,
230248
) {
231249
this.server.to(socketId).emit(ChatEvents.RunStepCreated, data);
232-
this.logger.log(
233-
`Socket "${ChatEvents.RunStepCreated}": ${data.runStep.status}`,
234-
);
250+
this.log(`Socket "${ChatEvents.RunStepCreated}": ${data.runStep.status}`);
235251
}
236252

237253
async emitRunStepDelta(
238254
socketId: string,
239255
@MessageBody() data: RunStepDeltaPayload,
240256
) {
241257
this.server.to(socketId).emit(ChatEvents.RunStepDelta, data);
242-
this.logger.log(
243-
`Socket "${ChatEvents.RunStepDelta}": ${data.runStep.status}`,
244-
);
258+
this.log(`Socket "${ChatEvents.RunStepDelta}": ${data.runStep.status}`);
245259
}
246260

247261
async emitRunStepDone(
248262
socketId: string,
249263
@MessageBody() data: RunStepDonePayload,
250264
) {
251265
this.server.to(socketId).emit(ChatEvents.RunStepDone, data);
252-
this.logger.log(
253-
`Socket "${ChatEvents.RunStepDone}": ${data.runStep.status}`,
254-
);
266+
this.log(`Socket "${ChatEvents.RunStepDone}": ${data.runStep.status}`);
255267
}
256268
}

libs/openai-assistant/src/lib/chat/chat.service.spec.ts

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,22 @@
11
import { Test } from '@nestjs/testing';
22
import { APIPromise } from 'openai/core';
33
import { Message, Run } from 'openai/resources/beta/threads';
4+
import { AssistantStream } from 'openai/lib/AssistantStream';
45
import { AiModule } from './../ai/ai.module';
56
import { ChatModule } from './chat.module';
67
import { ChatService } from './chat.service';
78
import { ChatHelpers } from './chat.helpers';
89
import { ChatCallDto } from './chat.model';
9-
import { AssistantStream } from 'openai/lib/AssistantStream';
10+
import { RunService } from '../run/run.service';
11+
12+
jest.mock('../stream/stream.utils', () => ({
13+
assistantStreamEventHandler: jest.fn(),
14+
}));
1015

1116
describe('ChatService', () => {
1217
let chatService: ChatService;
1318
let chatbotHelpers: ChatHelpers;
19+
let runService: RunService;
1420

1521
beforeEach(async () => {
1622
const moduleRef = await Test.createTestingModule({
@@ -19,18 +25,21 @@ describe('ChatService', () => {
1925

2026
chatService = moduleRef.get<ChatService>(ChatService);
2127
chatbotHelpers = moduleRef.get<ChatHelpers>(ChatHelpers);
28+
runService = moduleRef.get<RunService>(RunService);
29+
30+
jest.spyOn(runService, 'resolve').mockReturnThis();
2231

2332
jest
2433
.spyOn(chatbotHelpers, 'getAnswer')
2534
.mockReturnValue(Promise.resolve('Hello response') as Promise<string>);
2635

27-
2836
jest
2937
.spyOn(chatService.threads.messages, 'create')
3038
.mockReturnValue({} as APIPromise<Message>);
3139

32-
jest.spyOn(chatService, 'assistantStream').mockReturnValue({
40+
jest.spyOn(chatService, 'getAssistantStream').mockReturnValue({
3341
finalRun: jest.fn(),
42+
on: () => jest.fn(),
3443
} as unknown as Promise<AssistantStream>);
3544
});
3645

libs/openai-assistant/src/lib/chat/chat.service.ts

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import {
77
ChatCallResponseDto,
88
} from './chat.model';
99
import { ChatHelpers } from './chat.helpers';
10-
import { MessageCreateParams, Run } from 'openai/resources/beta/threads';
10+
import { Message, MessageCreateParams } from 'openai/resources/beta/threads';
1111
import { AssistantStream } from 'openai/lib/AssistantStream';
1212
import { assistantStreamEventHandler } from '../stream/stream.utils';
1313

@@ -26,6 +26,21 @@ export class ChatService {
2626
payload: ChatCallDto,
2727
callbacks?: ChatCallCallbacks,
2828
): Promise<ChatCallResponseDto> {
29+
await this.createMessage(payload);
30+
31+
const runner = await this.getAssistantStream(payload);
32+
assistantStreamEventHandler<AssistantStream>(runner, callbacks);
33+
34+
const finalRun = await runner.finalRun();
35+
await this.runService.resolve(await runner.finalRun(), true, callbacks);
36+
37+
return {
38+
content: await this.chatbotHelpers.getAnswer(finalRun),
39+
threadId: payload.threadId,
40+
};
41+
}
42+
43+
async createMessage(payload: ChatCallDto): Promise<Message> {
2944
const { threadId, content, file_ids, metadata } = payload;
3045
const message: MessageCreateParams = {
3146
role: 'user',
@@ -34,32 +49,15 @@ export class ChatService {
3449
metadata,
3550
};
3651

37-
await this.threads.messages.create(threadId, message);
38-
39-
const runner = await this.assistantStream(payload, callbacks);
40-
const finalRun = await runner.finalRun();
41-
42-
return {
43-
content: await this.chatbotHelpers.getAnswer(finalRun),
44-
threadId,
45-
};
52+
return this.threads.messages.create(threadId, message);
4653
}
4754

48-
async assistantStream(
49-
payload: ChatCallDto,
50-
callbacks?: ChatCallCallbacks,
51-
): Promise<AssistantStream> {
55+
async getAssistantStream(payload: ChatCallDto): Promise<AssistantStream> {
5256
const assistant_id =
5357
payload?.assistantId || process.env['ASSISTANT_ID'] || '';
5458

55-
const runner = this.threads.runs
56-
.createAndStream(payload.threadId, { assistant_id })
57-
.on('event', event => {
58-
if (event.event === 'thread.run.requires_action') {
59-
this.runService.submitAction(event.data, callbacks);
60-
}
61-
});
62-
63-
return assistantStreamEventHandler<AssistantStream>(runner, callbacks);
59+
return this.threads.runs.createAndStream(payload.threadId, {
60+
assistant_id,
61+
});
6462
}
6563
}

0 commit comments

Comments
 (0)