Skip to content

Commit abcf613

Browse files
author
Duc Nguyen
committed
Add URL base check for Databricks instance on AWS and GCP
1 parent 784bded commit abcf613

File tree

2 files changed

+43
-29
lines changed

2 files changed

+43
-29
lines changed

src/api/providers/__tests__/openai.test.ts

Lines changed: 38 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -394,43 +394,53 @@ describe("OpenAiHandler", () => {
394394
})
395395

396396
describe("Databricks AI Provider", () => {
397-
const databricksOptions = {
397+
const baseDatabricksOptions = {
398398
...mockOptions,
399-
openAiBaseUrl: "https://adb-xxxx.azuredatabricks.net/serving-endpoints",
400399
openAiModelId: "databricks-dbrx-instruct",
401400
}
402401

403-
it("should initialize with Databricks AI configuration", () => {
404-
const databricksHandler = new OpenAiHandler(databricksOptions)
402+
const databricksUrls = [
403+
"https://adb-xxxx.azuredatabricks.net/serving-endpoints",
404+
"https://myworkspace.cloud.databricks.com/serving-endpoints/myendpoint",
405+
"https://anotherworkspace.gcp.databricks.com/serving-endpoints/anotherendpoint",
406+
]
407+
408+
it.each(databricksUrls)("should initialize with Databricks AI configuration for %s", (url) => {
409+
const options = { ...baseDatabricksOptions, openAiBaseUrl: url }
410+
const databricksHandler = new OpenAiHandler(options)
405411
expect(databricksHandler).toBeInstanceOf(OpenAiHandler)
406-
expect(databricksHandler.getModel().id).toBe(databricksOptions.openAiModelId)
412+
expect(databricksHandler.getModel().id).toBe(options.openAiModelId)
407413
})
408414

409-
it("should exclude stream_options when streaming with Databricks AI", async () => {
410-
const databricksHandler = new OpenAiHandler(databricksOptions)
411-
const systemPrompt = "You are a helpful assistant."
412-
const messages: Anthropic.Messages.MessageParam[] = [
413-
{
414-
role: "user",
415-
content: "Hello!",
416-
},
417-
]
415+
it.each(databricksUrls)(
416+
"should exclude stream_options when streaming with Databricks AI for %s",
417+
async (url) => {
418+
const options = { ...baseDatabricksOptions, openAiBaseUrl: url }
419+
const databricksHandler = new OpenAiHandler(options)
420+
const systemPrompt = "You are a helpful assistant."
421+
const messages: Anthropic.Messages.MessageParam[] = [
422+
{
423+
role: "user",
424+
content: "Hello!",
425+
},
426+
]
418427

419-
const stream = databricksHandler.createMessage(systemPrompt, messages)
420-
await stream.next() // Consume one item to trigger the mock call
428+
const stream = databricksHandler.createMessage(systemPrompt, messages)
429+
await stream.next() // Consume one item to trigger the mock call
421430

422-
expect(mockCreate).toHaveBeenCalledWith(
423-
expect.objectContaining({
424-
model: databricksOptions.openAiModelId,
425-
stream: true,
426-
}),
427-
{}, // Expecting empty options object as second argument
428-
)
431+
expect(mockCreate).toHaveBeenCalledWith(
432+
expect.objectContaining({
433+
model: options.openAiModelId,
434+
stream: true,
435+
}),
436+
{}, // Expecting empty options object as second argument
437+
)
429438

430-
// Verify stream_options is not present in the last call's arguments
431-
const mockCalls = mockCreate.mock.calls
432-
const lastCallArgs = mockCalls[mockCalls.length - 1][0]
433-
expect(lastCallArgs).not.toHaveProperty("stream_options")
434-
})
439+
// Verify stream_options is not present in the last call's arguments
440+
const mockCalls = mockCreate.mock.calls
441+
const lastCallArgs = mockCalls[mockCalls.length - 1][0]
442+
expect(lastCallArgs).not.toHaveProperty("stream_options")
443+
},
444+
)
435445
})
436446
})

src/api/providers/openai.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,11 @@ export class OpenAiHandler extends BaseProvider implements SingleCompletionHandl
350350

351351
private _isDatabricksAI(baseUrl?: string): boolean {
352352
const urlHost = this._getUrlHost(baseUrl)
353-
return urlHost.includes(".azuredatabricks.net")
353+
return (
354+
urlHost.endsWith(".azuredatabricks.net") ||
355+
urlHost.endsWith(".cloud.databricks.com") ||
356+
urlHost.endsWith(".gcp.databricks.com")
357+
)
354358
}
355359

356360
private _isAzureAiInference(baseUrl?: string): boolean {

0 commit comments

Comments
 (0)