Skip to content

Commit c658a60

Browse files
chore: add AzureOpenAI model in the model list
1 parent 7747060 commit c658a60

File tree

6 files changed

+63
-5
lines changed

6 files changed

+63
-5
lines changed

package-lock.json

Lines changed: 36 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

package.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
},
3535
"license": "Apache-2.0",
3636
"devDependencies": {
37+
"@ai-sdk/azure": "^1.3.23",
3738
"@eslint/js": "^9.24.0",
3839
"@himanshusinghs/google": "^1.2.11",
3940
"@jest/globals": "^30.0.0",

tests/accuracy/list-collections.test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import { AccuracyTestConfig } from "./sdk/describe-accuracy-tests.js";
44

55
function describeListCollectionsAccuracyTests(prompt: string): AccuracyTestConfig {
66
return {
7-
systemPrompt: "Assume that you're already connected.",
7+
injectConnectedAssumption: true,
88
prompt: prompt,
99
mockedTools: {
1010
"list-collections": function listCollections() {

tests/accuracy/list-databases.test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ import { AccuracyTestConfig } from "./sdk/describe-accuracy-tests.js";
44

55
function describeListDatabasesAccuracyTests(prompt: string): AccuracyTestConfig {
66
return {
7-
systemPrompt: "Assume that you're already connected.",
7+
injectConnectedAssumption: true,
88
prompt: prompt,
99
mockedTools: {
1010
"list-databases": function listDatabases() {

tests/accuracy/sdk/describe-accuracy-tests.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import { appendAccuracySnapshot } from "./accuracy-snapshot.js";
77

88
export interface AccuracyTestConfig {
99
systemPrompt?: string;
10+
injectConnectedAssumption?: boolean;
1011
prompt: string;
1112
expectedToolCalls: ExpectedToolCall[];
1213
mockedTools: MockedTools;
@@ -44,7 +45,10 @@ export function describeAccuracyTests(
4445

4546
eachTest("$prompt", async function (testConfig) {
4647
testTools.mockTools(testConfig.mockedTools);
47-
const conversation = await agent.prompt(testConfig.prompt, model, testTools.vercelAiTools());
48+
const promptForModel = testConfig.injectConnectedAssumption
49+
? [testConfig.prompt, "(Assume that you are already connected to a MongoDB cluster!)"].join(" ")
50+
: testConfig.prompt;
51+
const conversation = await agent.prompt(promptForModel, model, testTools.vercelAiTools());
4852
const toolCalls = testTools.getToolCalls();
4953
const toolCallingAccuracy = toolCallingAccuracyScorer(testConfig.expectedToolCalls, toolCalls);
5054
const parameterMatchingAccuracy = parameterMatchingAccuracyScorer(testConfig.expectedToolCalls, toolCalls);

tests/accuracy/sdk/models.ts

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import { LanguageModelV1 } from "ai";
22
import { createGoogleGenerativeAI } from "@himanshusinghs/google";
3+
import { createAzure } from "@ai-sdk/azure";
34
import { ollama } from "ollama-ai-provider";
45

56
export interface Model<P extends LanguageModelV1 = LanguageModelV1> {
@@ -8,6 +9,22 @@ export interface Model<P extends LanguageModelV1 = LanguageModelV1> {
89
getModel(): P;
910
}
1011

12+
export class OpenAIModel implements Model {
13+
constructor(readonly modelName: string) {}
14+
15+
isAvailable(): boolean {
16+
return !!process.env.MDB_AZURE_OPEN_AI_API_KEY && !!process.env.MDB_AZURE_OPEN_AI_API_URL;
17+
}
18+
19+
getModel() {
20+
return createAzure({
21+
baseURL: process.env.MDB_AZURE_OPEN_AI_API_URL,
22+
apiKey: process.env.MDB_AZURE_OPEN_AI_API_KEY,
23+
apiVersion: "2024-12-01-preview",
24+
})(this.modelName);
25+
}
26+
}
27+
1128
export class GeminiModel implements Model {
1229
constructor(readonly modelName: string) {}
1330

@@ -35,9 +52,9 @@ export class OllamaModel implements Model {
3552
}
3653

3754
const ALL_TESTABLE_MODELS = [
38-
new GeminiModel("gemini-1.5-flash"),
3955
new GeminiModel("gemini-2.0-flash"),
40-
new OllamaModel("qwen3:1.7b"),
56+
new OpenAIModel("gpt-4o"),
57+
// new OllamaModel("qwen3:1.7b"),
4158
];
4259

4360
export type TestableModels = ReturnType<typeof getAvailableModels>;

0 commit comments

Comments
 (0)