diff --git a/src/api/providers/__tests__/openai.test.ts b/src/api/providers/__tests__/openai.test.ts
index 950b2165410..a516605ee51 100644
--- a/src/api/providers/__tests__/openai.test.ts
+++ b/src/api/providers/__tests__/openai.test.ts
@@ -392,4 +392,55 @@ describe("OpenAiHandler", () => {
 			expect(lastCall[0]).not.toHaveProperty("stream_options")
 		})
 	})
+
+	describe("Databricks AI Provider", () => {
+		const baseDatabricksOptions = {
+			...mockOptions,
+			openAiModelId: "databricks-dbrx-instruct",
+		}
+
+		const databricksUrls = [
+			"https://adb-xxxx.azuredatabricks.net/serving-endpoints",
+			"https://myworkspace.cloud.databricks.com/serving-endpoints/myendpoint",
+			"https://anotherworkspace.gcp.databricks.com/serving-endpoints/anotherendpoint",
+		]
+
+		it.each(databricksUrls)("should initialize with Databricks AI configuration for %s", (url) => {
+			const options = { ...baseDatabricksOptions, openAiBaseUrl: url }
+			const databricksHandler = new OpenAiHandler(options)
+			expect(databricksHandler).toBeInstanceOf(OpenAiHandler)
+			expect(databricksHandler.getModel().id).toBe(options.openAiModelId)
+		})
+
+		it.each(databricksUrls)(
+			"should exclude stream_options when streaming with Databricks AI for %s",
+			async (url) => {
+				const options = { ...baseDatabricksOptions, openAiBaseUrl: url }
+				const databricksHandler = new OpenAiHandler(options)
+				const systemPrompt = "You are a helpful assistant."
+				const messages: Anthropic.Messages.MessageParam[] = [
+					{
+						role: "user",
+						content: "Hello!",
+					},
+				]
+
+				const stream = databricksHandler.createMessage(systemPrompt, messages)
+				await stream.next() // Consume one item to trigger the mock call
+
+				expect(mockCreate).toHaveBeenCalledWith(
+					expect.objectContaining({
+						model: options.openAiModelId,
+						stream: true,
+					}),
+					{}, // Expecting empty options object as second argument
+				)
+
+				// Verify stream_options is not present in the last call's arguments
+				const mockCalls = mockCreate.mock.calls
+				const lastCallArgs = mockCalls[mockCalls.length - 1][0]
+				expect(lastCallArgs).not.toHaveProperty("stream_options")
+			},
+		)
+	})
 })
diff --git a/src/api/providers/openai.ts b/src/api/providers/openai.ts
index fc739b31105..b879bd319a6 100644
--- a/src/api/providers/openai.ts
+++ b/src/api/providers/openai.ts
@@ -138,13 +138,14 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 			}
 
 			const isGrokXAI = this._isGrokXAI(this.options.openAiBaseUrl)
+			const isDatabricksAI = this._isDatabricksAI(this.options.openAiBaseUrl)
 
 			const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
 				model: modelId,
 				temperature: this.options.modelTemperature ?? (deepseekReasoner ? DEEP_SEEK_DEFAULT_TEMPERATURE : 0),
 				messages: convertedMessages,
 				stream: true as const,
-				...(isGrokXAI ? {} : { stream_options: { include_usage: true } }),
+				...(isGrokXAI || isDatabricksAI ? {} : { stream_options: { include_usage: true } }),
 			}
 			if (this.options.includeMaxTokens) {
 				requestOptions.max_tokens = modelInfo.maxTokens
@@ -268,6 +269,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 
 			const methodIsAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl)
 			const isGrokXAI = this._isGrokXAI(this.options.openAiBaseUrl)
+			const isDatabricksAI = this._isDatabricksAI(this.options.openAiBaseUrl)
 
 			const stream = await this.client.chat.completions.create(
 				{
@@ -280,7 +282,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 						...convertToOpenAiMessages(messages),
 					],
 					stream: true,
-					...(isGrokXAI ? {} : { stream_options: { include_usage: true } }),
+					...(isGrokXAI || isDatabricksAI ? {} : { stream_options: { include_usage: true } }),
 					reasoning_effort: this.getModel().info.reasoningEffort,
 				},
 				methodIsAzureAiInference ? { path: AZURE_AI_INFERENCE_PATH } : {},
@@ -346,6 +348,25 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
 		return urlHost.includes("x.ai")
 	}
 
+	/**
+	 * Reports whether the configured base URL points at a Databricks workspace.
+	 *
+	 * The host obtained from `_getUrlHost` is matched by suffix against
+	 * `.azuredatabricks.net`, `.cloud.databricks.com` and `.gcp.databricks.com`.
+	 * Callers use the result to omit `stream_options` from streaming requests.
+	 *
+	 * @param baseUrl - Base URL of the OpenAI-compatible endpoint, if configured.
+	 * @returns `true` when the host ends with one of the suffixes above.
+	 */
+	private _isDatabricksAI(baseUrl?: string): boolean {
+		const urlHost = this._getUrlHost(baseUrl)
+		return (
+			urlHost.endsWith(".azuredatabricks.net") ||
+			urlHost.endsWith(".cloud.databricks.com") ||
+			urlHost.endsWith(".gcp.databricks.com")
+		)
+	}
+
 	private _isAzureAiInference(baseUrl?: string): boolean {
 		const urlHost = this._getUrlHost(baseUrl)
 		return urlHost.endsWith(".services.ai.azure.com")