Skip to content

Commit 5a2862f

Browse files
authored
feat: replace OpenAI with Vercel AI SDK (#2830)
Initial refactor to support tool/function calling by using the ai-sdk which supports it. Signed-off-by: Marc Nuri <[email protected]>
1 parent 4a42431 commit 5a2862f

File tree

6 files changed

+629
-226
lines changed

6 files changed

+629
-226
lines changed

packages/backend/package.json

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,8 +100,10 @@
100100
"typecheck": "pnpm run generate && tsc --noEmit"
101101
},
102102
"dependencies": {
103+
"@ai-sdk/openai-compatible": "^0.2.11",
103104
"@huggingface/gguf": "^0.1.14",
104105
"@huggingface/hub": "^1.1.2",
106+
"ai": "^4.3.6",
105107
"express": "^4.21.2",
106108
"express-openapi-validator": "^5.4.9",
107109
"isomorphic-git": "^1.30.1",
@@ -118,6 +120,7 @@
118120
},
119121
"devDependencies": {
120122
"@podman-desktop/api": "1.13.0-202409181313-78725a6565",
123+
"@ai-sdk/provider-utils": "^2.2.6",
121124
"@rollup/plugin-replace": "^6.0.2",
122125
"@types/express": "^4.17.21",
123126
"@types/js-yaml": "^4.0.9",
Lines changed: 226 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
/**********************************************************************
2+
* Copyright (C) 2025 Red Hat, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*
16+
* SPDX-License-Identifier: Apache-2.0
17+
***********************************************************************/
18+
19+
import { describe, test, expect, beforeEach, vi } from 'vitest';
20+
import * as ai from 'ai';
21+
import { AiStreamProcessor, toCoreMessage } from './aiSdk';
22+
import type {
23+
AssistantChat,
24+
ChatMessage,
25+
ErrorMessage,
26+
Message,
27+
PendingChat,
28+
UserChat,
29+
} from '@shared/models/IPlaygroundMessage';
30+
import type { LanguageModelV1, LanguageModelV1CallWarning, LanguageModelV1StreamPart } from '@ai-sdk/provider';
31+
// @ts-expect-error this is a test module
32+
import { convertArrayToReadableStream } from '@ai-sdk/provider-utils/test';
33+
import { ConversationRegistry } from '../../registries/ConversationRegistry';
34+
import type { RpcExtension } from '@shared/messages/MessageProxy';
35+
import type { ModelOptions } from '@shared/models/IModelOptions';
36+
37+
// Partially mock the 'ai' module: re-export the real implementation so tests
// exercise genuine streamText behavior, while spreading the module namespace
// into a fresh plain object makes the exports writable — ESM namespace
// exports are read-only, so this is what allows vi.spyOn(ai, 'streamText')
// to work in the tests below.
vi.mock('ai', async original => {
  const mod = (await original()) as object;
  return { ...mod };
});
41+
42+
describe('aiSdk', () => {
  beforeEach(() => {
    vi.resetAllMocks();
  });
  // Unit tests for the pure Message -> CoreMessage conversion helper.
  describe('toCoreMessage', () => {
    test('with no fields', () => {
      // An object with neither role nor content is not a chat message and is dropped.
      const result = toCoreMessage({} as Message);
      expect(result).toEqual([]);
    });
    test('with no role', () => {
      // Content alone does not qualify as a chat message either.
      const result = toCoreMessage({ content: 'alex' } as ChatMessage);
      expect(result).toEqual([]);
    });
    test('with no content', () => {
      // Missing content is normalized to an empty string, not dropped.
      const result = toCoreMessage({ role: 'user' } as ChatMessage);
      expect(result).toEqual([{ role: 'user', content: '' }]);
    });
    test('with all fields', () => {
      const result = toCoreMessage({ role: 'user', content: 'alex' } as ChatMessage);
      expect(result).toEqual([{ role: 'user', content: 'alex' }]);
    });
    test('with multiple messages', () => {
      // Order of the variadic arguments is preserved in the output.
      const result = toCoreMessage(
        { role: 'user', content: 'alex' } as ChatMessage,
        { role: 'assistant', content: 'bob' } as ChatMessage,
      );
      expect(result).toEqual([
        { role: 'user', content: 'alex' },
        { role: 'assistant', content: 'bob' },
      ]);
    });
  });
  describe('AiStreamProcessor', () => {
    let conversationRegistry: ConversationRegistry;
    let conversationId: string;
    beforeEach(() => {
      // Stub RPC transport: fire() resolves true so registry events succeed silently.
      const rpcExtension = {
        fire: vi.fn().mockResolvedValue(true),
      } as unknown as RpcExtension;
      conversationRegistry = new ConversationRegistry(rpcExtension);
      conversationId = conversationRegistry.createConversation('test-conversation', 'test-model');
      // Seed the conversation with one user message so streaming has input;
      // messages[0] is this user message in every assertion below.
      conversationRegistry.submit(conversationId, {
        content: 'Aitana, please proceed with the test',
        role: 'user',
        id: conversationRegistry.getUniqueId(),
        timestamp: Date.now(),
      } as UserChat);
    });
    test('sends model options', async () => {
      const streamTextSpy = vi.spyOn(ai, 'streamText');
      const streamProcessor = new AiStreamProcessor(conversationId, conversationRegistry);
      // ModelOptions uses snake_case (OpenAI style); the processor must map
      // them to the AI SDK's camelCase equivalents (max_tokens -> maxTokens, top_p -> topP).
      const streamResult = streamProcessor.stream(createTestModel(), {
        temperature: 42,
        top_p: 13,
        max_tokens: 37,
        stream_options: { include_usage: true },
      } as ModelOptions);
      await streamResult.consumeStream();
      expect(streamTextSpy).toHaveBeenCalledWith(
        expect.objectContaining({
          model: expect.anything(),
          temperature: 42,
          maxTokens: 37,
          topP: 13,
          abortSignal: expect.any(AbortSignal),
          messages: expect.any(Array),
          onStepFinish: expect.any(Function),
          onError: expect.any(Function),
          onChunk: expect.any(Function),
        }),
      );
    });
    test('abort, completes the last assistant message', async () => {
      // Submit a pending (completed: undefined) assistant message, point the
      // processor's private currentMessageId at it, then abort: the abort
      // handler must mark that message as completed.
      const incompleteMessageId = 'incomplete-message-id';
      conversationRegistry.submit(conversationId, {
        id: incompleteMessageId,
        role: 'assistant',
        timestamp: Date.now(),
        choices: [],
        completed: undefined,
      } as PendingChat);
      const streamProcessor = new AiStreamProcessor(conversationId, conversationRegistry);
      // Bracket access reaches the private field without widening the type.
      streamProcessor['currentMessageId'] = incompleteMessageId;
      streamProcessor.abortController.abort('cancel');
      expect(conversationRegistry.get(conversationId).messages).toHaveLength(2);
      expect((conversationRegistry.get(conversationId).messages[1] as AssistantChat).completed).not.toBeUndefined();
    });
    describe('with stream error', () => {
      beforeEach(async () => {
        // A model whose doStream rejects immediately; the processor's onError
        // callback should convert this into a single ErrorMessage.
        // eslint-disable-next-line sonarjs/no-nested-functions
        const doStream: LanguageModelV1['doStream'] = async () => {
          throw new Error('The stream is kaput.');
        };
        const model = new MockLanguageModelV1({ doStream });
        await new AiStreamProcessor(conversationId, conversationRegistry).stream(model).consumeStream();
      });
      test('appends a single message', () => {
        expect(conversationRegistry.get(conversationId).messages).toHaveLength(2);
      });
      test('appended message is error', () => {
        // The raw Error message is unwrapped and stored as the error string.
        expect((conversationRegistry.get(conversationId).messages[1] as ErrorMessage).error).toEqual(
          'The stream is kaput.',
        );
      });
    });
    describe('with single message stream', () => {
      let model: LanguageModelV1;
      beforeEach(async () => {
        // A scripted stream: metadata, three text deltas, then a finish part
        // carrying usage figures asserted in 'setsUsage' below.
        model = createTestModel({
          stream: convertArrayToReadableStream([
            {
              type: 'response-metadata',
              id: 'id-0',
              modelId: 'mock-model-id',
              timestamp: new Date(0),
            },
            { type: 'text-delta', textDelta: 'Greetings' },
            { type: 'text-delta', textDelta: ' professor ' },
            { type: 'text-delta', textDelta: `Falken` },
            { type: 'finish', finishReason: 'stop', usage: { completionTokens: 133, promptTokens: 7 } },
          ]),
        });
        await new AiStreamProcessor(conversationId, conversationRegistry).stream(model).consumeStream();
      });
      test('appends a single message', () => {
        // All deltas land in one assistant message, not one message per delta.
        expect(conversationRegistry.get(conversationId).messages).toHaveLength(2);
      });
      test('appended message is from assistant', () => {
        expect((conversationRegistry.get(conversationId).messages[1] as ChatMessage).role).toEqual('assistant');
      });
      test('concatenates message content', () => {
        expect((conversationRegistry.get(conversationId).messages[1] as ChatMessage).content).toEqual(
          'Greetings professor Falken',
        );
      });
      test('setsUsage', async () => {
        // Streams the same scripted model a second time; usage on the first
        // assistant message (index 1) must reflect the finish part's counts,
        // mapped back to snake_case fields.
        await new AiStreamProcessor(conversationId, conversationRegistry).stream(model).consumeStream();
        const message = conversationRegistry.get(conversationId).messages[1] as ChatMessage;
        expect(message?.usage?.completion_tokens).toEqual(133);
        expect(message?.usage?.prompt_tokens).toEqual(7);
      });
    });
  });
});
186+
187+
export class MockLanguageModelV1 implements LanguageModelV1 {
188+
readonly specificationVersion = 'v1';
189+
readonly provider: LanguageModelV1['provider'];
190+
readonly modelId: LanguageModelV1['modelId'];
191+
192+
supportsUrl: LanguageModelV1['supportsUrl'];
193+
doGenerate: LanguageModelV1['doGenerate'];
194+
doStream: LanguageModelV1['doStream'];
195+
196+
readonly defaultObjectGenerationMode: LanguageModelV1['defaultObjectGenerationMode'];
197+
readonly supportsStructuredOutputs: LanguageModelV1['supportsStructuredOutputs'];
198+
constructor({ doStream = notImplemented }: { doStream?: LanguageModelV1['doStream'] }) {
199+
this.provider = 'mock-model-provider';
200+
this.modelId = 'mock-model-id';
201+
this.doGenerate = notImplemented;
202+
this.doStream = doStream;
203+
}
204+
}
205+
206+
function notImplemented(): never {
207+
throw new Error('Not implemented');
208+
}
209+
210+
export function createTestModel({
211+
stream = convertArrayToReadableStream([]),
212+
rawCall = { rawPrompt: 'prompt', rawSettings: {} },
213+
rawResponse = undefined,
214+
request = undefined,
215+
warnings,
216+
}: {
217+
stream?: ReadableStream<LanguageModelV1StreamPart>;
218+
rawResponse?: { headers: Record<string, string> };
219+
rawCall?: { rawPrompt: string; rawSettings: Record<string, unknown> };
220+
request?: { body: string };
221+
warnings?: LanguageModelV1CallWarning[];
222+
} = {}): LanguageModelV1 {
223+
return new MockLanguageModelV1({
224+
doStream: async () => ({ stream, rawCall, rawResponse, request, warnings }),
225+
});
226+
}
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
/**********************************************************************
2+
* Copyright (C) 2025 Red Hat, Inc.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*
16+
* SPDX-License-Identifier: Apache-2.0
17+
***********************************************************************/
18+
19+
import { streamText } from 'ai';
20+
import type { LanguageModel, CoreMessage, StepResult, StreamTextResult, TextStreamPart, ToolSet } from 'ai';
21+
import type { ModelOptions } from '@shared/models/IModelOptions';
22+
import type {
23+
ChatMessage,
24+
Choice,
25+
ErrorMessage,
26+
Message,
27+
ModelUsage,
28+
PendingChat,
29+
} from '@shared/models/IPlaygroundMessage';
30+
import { isChatMessage } from '@shared/models/IPlaygroundMessage';
31+
import type { ConversationRegistry } from '../../registries/ConversationRegistry';
32+
33+
export function toCoreMessage(...messages: Message[]): CoreMessage[] {
34+
return messages
35+
.filter(m => isChatMessage(m))
36+
.map(
37+
(message: ChatMessage) =>
38+
({
39+
role: message.role,
40+
content: message.content ?? '',
41+
}) as CoreMessage,
42+
);
43+
}
44+
45+
/**
 * Bridges the Vercel AI SDK's streamText callbacks to the playground's
 * ConversationRegistry.
 *
 * One instance handles one streamed response. `currentMessageId` is the
 * pending assistant message being built up from text deltas; it is created
 * lazily on the first 'text-delta' chunk, completed on step finish (or
 * abort), and reset between steps. Cancellation flows through the exposed
 * `abortController`.
 */
export class AiStreamProcessor<TOOLS extends ToolSet> {
  // Id of the in-flight assistant message, or undefined when no message is
  // currently being assembled from stream chunks.
  private currentMessageId: string | undefined;
  // Public so callers (and tests) can cancel the stream externally.
  public readonly abortController: AbortController;

  constructor(
    private conversationId: string,
    private conversationRegistry: ConversationRegistry,
  ) {
    this.abortController = new AbortController();
    // Arrow-function field keeps `this` bound when the SDK invokes the handler.
    this.abortController.signal.addEventListener('abort', this.onAbort);
  }

  // Called by the SDK when a generation step ends: record token usage on the
  // pending message (mapping camelCase SDK fields to snake_case ModelUsage),
  // mark it complete, and reset state for a possible next step.
  private onStepFinish = (stepResult: StepResult<TOOLS>): void => {
    if (this.currentMessageId !== undefined) {
      this.conversationRegistry.setUsage(this.conversationId, this.currentMessageId, {
        completion_tokens: stepResult.usage.completionTokens,
        prompt_tokens: stepResult.usage.promptTokens,
      } as ModelUsage);
      // TODO, this doesn't seem very wise (using choices as partial state holder)
      // Refactor to use this.conversationRegistry.update instead
      this.conversationRegistry.completeMessage(this.conversationId, this.currentMessageId);
    }
    this.currentMessageId = undefined;
  };

  // Called by the SDK per stream part. Only 'text-delta' parts are handled;
  // tool/reasoning/source parts are ignored for now (NOTE(review): tool-call
  // parts presumably get handling in a follow-up — the Extract type already
  // admits them).
  private onChunk = ({
    chunk,
  }: {
    chunk: Extract<
      TextStreamPart<TOOLS>,
      {
        type:
          | 'text-delta'
          | 'reasoning'
          | 'source'
          | 'tool-call'
          | 'tool-call-streaming-start'
          | 'tool-call-delta'
          | 'tool-result';
      }
    >;
  }): void => {
    if (chunk.type !== 'text-delta') {
      return;
    }
    // First delta of a step: open a pending assistant message to accumulate into.
    if (this.currentMessageId === undefined) {
      this.currentMessageId = this.conversationRegistry.getUniqueId();
      this.conversationRegistry.submit(this.conversationId, {
        id: this.currentMessageId,
        role: 'assistant',
        timestamp: Date.now(),
        choices: [],
        completed: undefined,
      } as PendingChat);
    }
    // TODO, this doesn't seem very wise (using choices as partial state holder)
    // Refactor to use this.conversationRegistry.update instead
    this.conversationRegistry.appendChoice(this.conversationId, this.currentMessageId, {
      content: chunk.textDelta,
    } as Choice);
  };

  // Called by the SDK on stream failure. Unwraps one { error } wrapper and
  // Error instances down to a plain string, then records an ErrorMessage in
  // the conversation so the UI can surface it.
  private onError = (error: unknown): void => {
    if (error instanceof Object && 'error' in error) {
      error = error.error;
    }
    if (error instanceof Error) {
      error = error.message;
    }
    let errorMessage = String(error);
    // Context-window overflow (OpenAI-compatible wording): tell the user the
    // only real fix is a fresh conversation.
    if (errorMessage.endsWith('Please reduce the length of the messages or completion.')) {
      errorMessage += ' Note: You should start a new playground.';
    }
    console.error('Something went wrong while creating model response', errorMessage);
    this.conversationRegistry.submit(this.conversationId, {
      id: this.conversationRegistry.getUniqueId(),
      timestamp: Date.now(),
      error: errorMessage,
    } as ErrorMessage);
  };

  private onAbort = (): void => {
    // Ensure the last message is marked as complete to allow the user to resume the conversation
    if (this.currentMessageId !== undefined) {
      // TODO, this doesn't seem very wise (using choices as partial state holder)
      // Refactor to use this.conversationRegistry.update instead
      this.conversationRegistry.completeMessage(this.conversationId, this.currentMessageId);
    }
  };

  /**
   * Starts streaming a response for the conversation's current messages.
   *
   * Maps snake_case ModelOptions onto the SDK's camelCase parameters; a
   * missing or non-positive max_tokens becomes undefined (no limit). The
   * caller is expected to consume the returned StreamTextResult.
   */
  stream = (model: LanguageModel, options?: ModelOptions): StreamTextResult<TOOLS, never> => {
    return streamText({
      model,
      temperature: options?.temperature,
      maxTokens: (options?.max_tokens ?? -1) < 1 ? undefined : options?.max_tokens,
      topP: options?.top_p,
      abortSignal: this.abortController.signal,
      messages: toCoreMessage(...this.conversationRegistry.get(this.conversationId).messages),
      onStepFinish: this.onStepFinish,
      onError: this.onError,
      onChunk: this.onChunk,
    });
  };
}

0 commit comments

Comments
 (0)