Skip to content

Commit f3a9881

Browse files
authored
test(mongodb): use azure openid when available in tests (#8588)
2 parents fa77039 + b8211a5 commit f3a9881

File tree

1 file changed

+40
-32
lines changed

1 file changed

+40
-32
lines changed

libs/langchain-mongodb/src/tests/vectorstores.int.test.ts

Lines changed: 40 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
import { beforeAll, expect, jest, test } from "@jest/globals";
55
import { Collection, MongoClient } from "mongodb";
66
import { setTimeout } from "timers/promises";
7-
import { OpenAIEmbeddings } from "@langchain/openai";
7+
import { OpenAIEmbeddings, AzureOpenAIEmbeddings } from "@langchain/openai";
88
import { Document } from "@langchain/core/documents";
99
// eslint-disable-next-line import/no-extraneous-dependencies
1010
import { Document as BSONDocument } from "bson";
1111

12+
import { EmbeddingsInterface } from "@langchain/core/embeddings";
1213
import { MongoDBAtlasVectorSearch } from "../vectorstores.js";
1314
import { isUsingLocalAtlas, uri, waitForIndexToBeQueryable } from "./utils.js";
1415

@@ -102,8 +103,18 @@ class PatchedVectorStore extends MongoDBAtlasVectorSearch {
102103
}
103104
}
104105

106+
function getEmbeddings() {
107+
if (process.env.AZURE_OPENAI_API_KEY) {
108+
return new AzureOpenAIEmbeddings({
109+
model: "text-embedding-3-small",
110+
azureOpenAIApiDeploymentName: "openai/deployments/text-embedding-3-small",
111+
});
112+
}
113+
return new OpenAIEmbeddings();
114+
}
115+
105116
test("MongoDBAtlasVectorSearch with external ids", async () => {
106-
const vectorStore = new PatchedVectorStore(new OpenAIEmbeddings(), {
117+
const vectorStore = new PatchedVectorStore(getEmbeddings(), {
107118
collection,
108119
});
109120

@@ -166,7 +177,7 @@ test("MongoDBAtlasVectorSearch with Maximal Marginal Relevance", async () => {
166177
const vectorStore = await PatchedVectorStore.fromTexts(
167178
texts,
168179
{},
169-
new OpenAIEmbeddings(),
180+
getEmbeddings(),
170181
{ collection, indexName: "default" }
171182
);
172183

@@ -215,7 +226,7 @@ test("MongoDBAtlasVectorSearch with Maximal Marginal Relevance", async () => {
215226
});
216227

217228
test("MongoDBAtlasVectorSearch upsert", async () => {
218-
const vectorStore = new PatchedVectorStore(new OpenAIEmbeddings(), {
229+
const vectorStore = new PatchedVectorStore(getEmbeddings(), {
219230
collection,
220231
});
221232

@@ -253,15 +264,15 @@ test("MongoDBAtlasVectorSearch upsert", async () => {
253264

254265
describe("MongoDBAtlasVectorSearch Constructor", () => {
255266
test("initializes with minimal configuration", () => {
256-
const vectorStore = new MongoDBAtlasVectorSearch(new OpenAIEmbeddings(), {
267+
const vectorStore = new MongoDBAtlasVectorSearch(getEmbeddings(), {
257268
collection,
258269
});
259270
expect(vectorStore).toBeDefined();
260271
});
261272

262273
test("initializes with custom index name", () => {
263274
const customIndexName = "custom_index";
264-
const vectorStore = new MongoDBAtlasVectorSearch(new OpenAIEmbeddings(), {
275+
const vectorStore = new MongoDBAtlasVectorSearch(getEmbeddings(), {
265276
collection,
266277
indexName: customIndexName,
267278
});
@@ -271,7 +282,7 @@ describe("MongoDBAtlasVectorSearch Constructor", () => {
271282
});
272283

273284
test("initializes with custom field names", () => {
274-
const vectorStore = new MongoDBAtlasVectorSearch(new OpenAIEmbeddings(), {
285+
const vectorStore = new MongoDBAtlasVectorSearch(getEmbeddings(), {
275286
collection,
276287
textKey: "content",
277288
embeddingKey: "vector",
@@ -287,7 +298,7 @@ describe("MongoDBAtlasVectorSearch Constructor", () => {
287298
});
288299

289300
test("initializes AsyncCaller with custom parameters", () => {
290-
const vectorStore = new MongoDBAtlasVectorSearch(new OpenAIEmbeddings(), {
301+
const vectorStore = new MongoDBAtlasVectorSearch(getEmbeddings(), {
291302
collection,
292303
maxConcurrency: 5,
293304
maxRetries: 3,
@@ -304,7 +315,7 @@ describe("MongoDBAtlasVectorSearch Constructor", () => {
304315
});
305316

306317
describe("addVectors method", () => {
307-
let embeddings: OpenAIEmbeddings;
318+
let embeddings: EmbeddingsInterface;
308319
let vectorStore: PatchedVectorStore;
309320
let vectors: number[][];
310321
const documents = [
@@ -313,8 +324,8 @@ describe("addVectors method", () => {
313324
];
314325

315326
beforeEach(async () => {
316-
embeddings = new OpenAIEmbeddings();
317-
vectorStore = new PatchedVectorStore(new OpenAIEmbeddings(), {
327+
embeddings = getEmbeddings();
328+
vectorStore = new PatchedVectorStore(getEmbeddings(), {
318329
collection,
319330
});
320331
vectors = await embeddings.embedDocuments(["test 1", "test 2"]);
@@ -388,14 +399,14 @@ describe("addVectors method", () => {
388399
});
389400

390401
describe("addDocuments method", () => {
391-
let embeddings: OpenAIEmbeddings;
402+
let embeddings: EmbeddingsInterface;
392403
let vectorStore: PatchedVectorStore;
393404
const documents = [
394405
new Document({ pageContent: "test 1" }),
395406
new Document({ pageContent: "test 2" }),
396407
];
397408
beforeEach(async () => {
398-
embeddings = new OpenAIEmbeddings();
409+
embeddings = getEmbeddings();
399410
vectorStore = new PatchedVectorStore(embeddings, {
400411
collection,
401412
});
@@ -474,10 +485,10 @@ describe("addDocuments method", () => {
474485
});
475486

476487
describe("similaritySearchVectorWithScore method", () => {
477-
let embeddings: OpenAIEmbeddings;
488+
let embeddings: EmbeddingsInterface;
478489
let vectorStore: PatchedVectorStore;
479490
beforeEach(async () => {
480-
embeddings = new OpenAIEmbeddings();
491+
embeddings = getEmbeddings();
481492
vectorStore = new PatchedVectorStore(embeddings, {
482493
collection,
483494
});
@@ -717,7 +728,7 @@ describe("delete method", () => {
717728
});
718729

719730
test("removes documents by ids", async () => {
720-
const vectorStore = new PatchedVectorStore(new OpenAIEmbeddings(), {
731+
const vectorStore = new PatchedVectorStore(getEmbeddings(), {
721732
collection,
722733
});
723734

@@ -767,7 +778,7 @@ describe("delete method", () => {
767778
});
768779

769780
test("ignores non-existent ids", async () => {
770-
const vectorStore = new PatchedVectorStore(new OpenAIEmbeddings(), {
781+
const vectorStore = new PatchedVectorStore(getEmbeddings(), {
771782
collection,
772783
});
773784

@@ -783,10 +794,10 @@ describe("delete method", () => {
783794

784795
describe("Static Methods", () => {
785796
describe("fromTexts", () => {
786-
let embeddings: OpenAIEmbeddings;
797+
let embeddings: EmbeddingsInterface;
787798
const texts = ["text1", "text2", "text3"];
788799
beforeEach(() => {
789-
embeddings = new OpenAIEmbeddings();
800+
embeddings = getEmbeddings();
790801
});
791802

792803
test("populates a vector store from strings with a metadata object", async () => {
@@ -838,7 +849,7 @@ describe("Static Methods", () => {
838849
];
839850
const store = await MongoDBAtlasVectorSearch.fromDocuments(
840851
documents,
841-
new OpenAIEmbeddings(),
852+
getEmbeddings(),
842853
{ collection }
843854
);
844855
expect(store).toBeInstanceOf(MongoDBAtlasVectorSearch);
@@ -850,11 +861,9 @@ describe("Static Methods", () => {
850861
new Document({ pageContent: "doc2", metadata: { source: "source2" } }),
851862
];
852863

853-
await MongoDBAtlasVectorSearch.fromDocuments(
854-
documents,
855-
new OpenAIEmbeddings(),
856-
{ collection }
857-
);
864+
await MongoDBAtlasVectorSearch.fromDocuments(documents, getEmbeddings(), {
865+
collection,
866+
});
858867

859868
const results = await collection
860869
.find({}, { projection: { _id: 0, text: 1, embedding: 1 } })
@@ -873,11 +882,10 @@ describe("Static Methods", () => {
873882
new Document({ pageContent: "doc2", metadata: { source: "source2" } }),
874883
];
875884

876-
await MongoDBAtlasVectorSearch.fromDocuments(
877-
documents,
878-
new OpenAIEmbeddings(),
879-
{ collection, ids: ["custom1", "custom2"] }
880-
);
885+
await MongoDBAtlasVectorSearch.fromDocuments(documents, getEmbeddings(), {
886+
collection,
887+
ids: ["custom1", "custom2"],
888+
});
881889

882890
const results = await collection
883891
.find({}, { projection: { _id: 1, text: 1 } })
@@ -896,7 +904,7 @@ describe("Static Methods", () => {
896904
new Document({ pageContent: "doc1" }),
897905
new Document({ pageContent: "doc2" }),
898906
],
899-
new OpenAIEmbeddings(),
907+
getEmbeddings(),
900908
{ collection, ids: ["id1", "id2"] }
901909
);
902910

@@ -905,7 +913,7 @@ describe("Static Methods", () => {
905913
new Document({ pageContent: "updated 1" }),
906914
new Document({ pageContent: "updated 2" }),
907915
],
908-
new OpenAIEmbeddings(),
916+
getEmbeddings(),
909917
{ collection, ids: ["id1", "id2"] }
910918
);
911919

0 commit comments

Comments
 (0)