Skip to content

Commit 784bded

Browse files
author
Duc Nguyen
committed
Enable streaming for Databricks models in the OpenAI Compatible provider
1 parent b8bf634 commit 784bded

File tree

2 files changed

+50
-2
lines changed

2 files changed

+50
-2
lines changed

src/api/providers/__tests__/openai.test.ts

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,4 +392,45 @@ describe("OpenAiHandler", () => {
392392
expect(lastCall[0]).not.toHaveProperty("stream_options")
393393
})
394394
})
395+
396+
describe("Databricks AI Provider", () => {
397+
const databricksOptions = {
398+
...mockOptions,
399+
openAiBaseUrl: "https://adb-xxxx.azuredatabricks.net/serving-endpoints",
400+
openAiModelId: "databricks-dbrx-instruct",
401+
}
402+
403+
it("should initialize with Databricks AI configuration", () => {
404+
const databricksHandler = new OpenAiHandler(databricksOptions)
405+
expect(databricksHandler).toBeInstanceOf(OpenAiHandler)
406+
expect(databricksHandler.getModel().id).toBe(databricksOptions.openAiModelId)
407+
})
408+
409+
it("should exclude stream_options when streaming with Databricks AI", async () => {
410+
const databricksHandler = new OpenAiHandler(databricksOptions)
411+
const systemPrompt = "You are a helpful assistant."
412+
const messages: Anthropic.Messages.MessageParam[] = [
413+
{
414+
role: "user",
415+
content: "Hello!",
416+
},
417+
]
418+
419+
const stream = databricksHandler.createMessage(systemPrompt, messages)
420+
await stream.next() // Consume one item to trigger the mock call
421+
422+
expect(mockCreate).toHaveBeenCalledWith(
423+
expect.objectContaining({
424+
model: databricksOptions.openAiModelId,
425+
stream: true,
426+
}),
427+
{}, // Expecting empty options object as second argument
428+
)
429+
430+
// Verify stream_options is not present in the last call's arguments
431+
const mockCalls = mockCreate.mock.calls
432+
const lastCallArgs = mockCalls[mockCalls.length - 1][0]
433+
expect(lastCallArgs).not.toHaveProperty("stream_options")
434+
})
435+
})
395436
})

src/api/providers/openai.ts

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,13 +138,14 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
138138
}
139139

140140
const isGrokXAI = this._isGrokXAI(this.options.openAiBaseUrl)
141+
const isDatabricksAI = this._isDatabricksAI(this.options.openAiBaseUrl)
141142

142143
const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
143144
model: modelId,
144145
temperature: this.options.modelTemperature ?? (deepseekReasoner ? DEEP_SEEK_DEFAULT_TEMPERATURE : 0),
145146
messages: convertedMessages,
146147
stream: true as const,
147-
...(isGrokXAI ? {} : { stream_options: { include_usage: true } }),
148+
...(isGrokXAI || isDatabricksAI ? {} : { stream_options: { include_usage: true } }),
148149
}
149150
if (this.options.includeMaxTokens) {
150151
requestOptions.max_tokens = modelInfo.maxTokens
@@ -268,6 +269,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
268269
const methodIsAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl)
269270

270271
const isGrokXAI = this._isGrokXAI(this.options.openAiBaseUrl)
272+
const isDatabricksAI = this._isDatabricksAI(this.options.openAiBaseUrl)
271273

272274
const stream = await this.client.chat.completions.create(
273275
{
@@ -280,7 +282,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
280282
...convertToOpenAiMessages(messages),
281283
],
282284
stream: true,
283-
...(isGrokXAI ? {} : { stream_options: { include_usage: true } }),
285+
...(isGrokXAI || isDatabricksAI ? {} : { stream_options: { include_usage: true } }),
284286
reasoning_effort: this.getModel().info.reasoningEffort,
285287
},
286288
methodIsAzureAiInference ? { path: AZURE_AI_INFERENCE_PATH } : {},
@@ -346,6 +348,11 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
346348
return urlHost.includes("x.ai")
347349
}
348350

351+
/**
 * Returns true when the configured base URL points at an Azure Databricks
 * serving endpoint (host ending in ".azuredatabricks.net").
 *
 * Uses endsWith rather than includes so that a hostname such as
 * "adb-x.azuredatabricks.net.evil.com" is not misclassified as a
 * Databricks endpoint; this also matches the suffix check used by
 * _isAzureAiInference.
 */
private _isDatabricksAI(baseUrl?: string): boolean {
	const urlHost = this._getUrlHost(baseUrl)
	return urlHost.endsWith(".azuredatabricks.net")
}
355+
349356
private _isAzureAiInference(baseUrl?: string): boolean {
350357
const urlHost = this._getUrlHost(baseUrl)
351358
return urlHost.endsWith(".services.ai.azure.com")

0 commit comments

Comments
 (0)