src/api/providers/__tests__/openai.test.ts — 51 additions & 0 deletions
@@ -392,4 +392,55 @@ describe("OpenAiHandler", () => {
expect(lastCall[0]).not.toHaveProperty("stream_options")
})
})

describe("Databricks AI Provider", () => {
const baseDatabricksOptions = {
...mockOptions,
openAiModelId: "databricks-dbrx-instruct",
}

const databricksUrls = [
"https://adb-xxxx.azuredatabricks.net/serving-endpoints",
"https://myworkspace.cloud.databricks.com/serving-endpoints/myendpoint",
"https://anotherworkspace.gcp.databricks.com/serving-endpoints/anotherendpoint",
]

it.each(databricksUrls)("should initialize with Databricks AI configuration for %s", (url) => {
const options = { ...baseDatabricksOptions, openAiBaseUrl: url }
const databricksHandler = new OpenAiHandler(options)
expect(databricksHandler).toBeInstanceOf(OpenAiHandler)
expect(databricksHandler.getModel().id).toBe(options.openAiModelId)
})

it.each(databricksUrls)(
"should exclude stream_options when streaming with Databricks AI for %s",
async (url) => {
const options = { ...baseDatabricksOptions, openAiBaseUrl: url }
const databricksHandler = new OpenAiHandler(options)
const systemPrompt = "You are a helpful assistant."
const messages: Anthropic.Messages.MessageParam[] = [
{
role: "user",
content: "Hello!",
},
]

const stream = databricksHandler.createMessage(systemPrompt, messages)
await stream.next() // Consume one item to trigger the mock call

expect(mockCreate).toHaveBeenCalledWith(
expect.objectContaining({
model: options.openAiModelId,
stream: true,
}),
{}, // Expecting empty options object as second argument
)

// Verify stream_options is not present in the last call's arguments
const mockCalls = mockCreate.mock.calls
const lastCallArgs = mockCalls[mockCalls.length - 1][0]
expect(lastCallArgs).not.toHaveProperty("stream_options")
},
)
})
})
src/api/providers/openai.ts — 13 additions & 2 deletions
@@ -138,13 +138,14 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
}

const isGrokXAI = this._isGrokXAI(this.options.openAiBaseUrl)
+		const isDatabricksAI = this._isDatabricksAI(this.options.openAiBaseUrl)

const requestOptions: OpenAI.Chat.Completions.ChatCompletionCreateParamsStreaming = {
model: modelId,
temperature: this.options.modelTemperature ?? (deepseekReasoner ? DEEP_SEEK_DEFAULT_TEMPERATURE : 0),
messages: convertedMessages,
stream: true as const,
-			...(isGrokXAI ? {} : { stream_options: { include_usage: true } }),
+			...(isGrokXAI || isDatabricksAI ? {} : { stream_options: { include_usage: true } }),
Review comment (Contributor): Consider extracting the (isGrokXAI || isDatabricksAI) check into a helper function for clarity and maintainability if similar conditions are repeated in the future.
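A minimal sketch of that suggestion, assuming a hypothetical helper name (_shouldIncludeStreamOptions, not part of this PR) and the class's existing _isGrokXAI/_isDatabricksAI predicates:

	// Hypothetical helper: one place to list providers that reject stream_options.
	private _shouldIncludeStreamOptions(baseUrl?: string): boolean {
		return !this._isGrokXAI(baseUrl) && !this._isDatabricksAI(baseUrl)
	}

	// Both call sites could then replace the inline ternary with:
	...(this._shouldIncludeStreamOptions(this.options.openAiBaseUrl)
		? { stream_options: { include_usage: true } }
		: {}),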

}
if (this.options.includeMaxTokens) {
requestOptions.max_tokens = modelInfo.maxTokens
@@ -268,6 +269,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
const methodIsAzureAiInference = this._isAzureAiInference(this.options.openAiBaseUrl)

const isGrokXAI = this._isGrokXAI(this.options.openAiBaseUrl)
+		const isDatabricksAI = this._isDatabricksAI(this.options.openAiBaseUrl)

const stream = await this.client.chat.completions.create(
{
@@ -280,7 +282,7 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
...convertToOpenAiMessages(messages),
],
stream: true,
-				...(isGrokXAI ? {} : { stream_options: { include_usage: true } }),
+				...(isGrokXAI || isDatabricksAI ? {} : { stream_options: { include_usage: true } }),
reasoning_effort: this.getModel().info.reasoningEffort,
},
methodIsAzureAiInference ? { path: AZURE_AI_INFERENCE_PATH } : {},
@@ -346,6 +348,15 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
return urlHost.includes("x.ai")
}

+	private _isDatabricksAI(baseUrl?: string): boolean {
+		const urlHost = this._getUrlHost(baseUrl)
+		return (
+			urlHost.endsWith(".azuredatabricks.net") ||
+			urlHost.endsWith(".cloud.databricks.com") ||
+			urlHost.endsWith(".gcp.databricks.com")
+		)
+	}
+
private _isAzureAiInference(baseUrl?: string): boolean {
const urlHost = this._getUrlHost(baseUrl)
return urlHost.endsWith(".services.ai.azure.com")