4
4
import { beforeAll , expect , jest , test } from "@jest/globals" ;
5
5
import { Collection , MongoClient } from "mongodb" ;
6
6
import { setTimeout } from "timers/promises" ;
7
- import { OpenAIEmbeddings } from "@langchain/openai" ;
7
+ import { OpenAIEmbeddings , AzureOpenAIEmbeddings } from "@langchain/openai" ;
8
8
import { Document } from "@langchain/core/documents" ;
9
9
// eslint-disable-next-line import/no-extraneous-dependencies
10
10
import { Document as BSONDocument } from "bson" ;
11
11
12
+ import { EmbeddingsInterface } from "@langchain/core/embeddings" ;
12
13
import { MongoDBAtlasVectorSearch } from "../vectorstores.js" ;
13
14
import { isUsingLocalAtlas , uri , waitForIndexToBeQueryable } from "./utils.js" ;
14
15
@@ -102,8 +103,18 @@ class PatchedVectorStore extends MongoDBAtlasVectorSearch {
102
103
}
103
104
}
104
105
106
+ function getEmbeddings ( ) {
107
+ if ( process . env . AZURE_OPENAI_API_KEY ) {
108
+ return new AzureOpenAIEmbeddings ( {
109
+ model : "text-embedding-3-small" ,
110
+ azureOpenAIApiDeploymentName : "openai/deployments/text-embedding-3-small" ,
111
+ } ) ;
112
+ }
113
+ return new OpenAIEmbeddings ( ) ;
114
+ }
115
+
105
116
test ( "MongoDBAtlasVectorSearch with external ids" , async ( ) => {
106
- const vectorStore = new PatchedVectorStore ( new OpenAIEmbeddings ( ) , {
117
+ const vectorStore = new PatchedVectorStore ( getEmbeddings ( ) , {
107
118
collection,
108
119
} ) ;
109
120
@@ -166,7 +177,7 @@ test("MongoDBAtlasVectorSearch with Maximal Marginal Relevance", async () => {
166
177
const vectorStore = await PatchedVectorStore . fromTexts (
167
178
texts ,
168
179
{ } ,
169
- new OpenAIEmbeddings ( ) ,
180
+ getEmbeddings ( ) ,
170
181
{ collection, indexName : "default" }
171
182
) ;
172
183
@@ -215,7 +226,7 @@ test("MongoDBAtlasVectorSearch with Maximal Marginal Relevance", async () => {
215
226
} ) ;
216
227
217
228
test ( "MongoDBAtlasVectorSearch upsert" , async ( ) => {
218
- const vectorStore = new PatchedVectorStore ( new OpenAIEmbeddings ( ) , {
229
+ const vectorStore = new PatchedVectorStore ( getEmbeddings ( ) , {
219
230
collection,
220
231
} ) ;
221
232
@@ -253,15 +264,15 @@ test("MongoDBAtlasVectorSearch upsert", async () => {
253
264
254
265
describe ( "MongoDBAtlasVectorSearch Constructor" , ( ) => {
255
266
test ( "initializes with minimal configuration" , ( ) => {
256
- const vectorStore = new MongoDBAtlasVectorSearch ( new OpenAIEmbeddings ( ) , {
267
+ const vectorStore = new MongoDBAtlasVectorSearch ( getEmbeddings ( ) , {
257
268
collection,
258
269
} ) ;
259
270
expect ( vectorStore ) . toBeDefined ( ) ;
260
271
} ) ;
261
272
262
273
test ( "initializes with custom index name" , ( ) => {
263
274
const customIndexName = "custom_index" ;
264
- const vectorStore = new MongoDBAtlasVectorSearch ( new OpenAIEmbeddings ( ) , {
275
+ const vectorStore = new MongoDBAtlasVectorSearch ( getEmbeddings ( ) , {
265
276
collection,
266
277
indexName : customIndexName ,
267
278
} ) ;
@@ -271,7 +282,7 @@ describe("MongoDBAtlasVectorSearch Constructor", () => {
271
282
} ) ;
272
283
273
284
test ( "initializes with custom field names" , ( ) => {
274
- const vectorStore = new MongoDBAtlasVectorSearch ( new OpenAIEmbeddings ( ) , {
285
+ const vectorStore = new MongoDBAtlasVectorSearch ( getEmbeddings ( ) , {
275
286
collection,
276
287
textKey : "content" ,
277
288
embeddingKey : "vector" ,
@@ -287,7 +298,7 @@ describe("MongoDBAtlasVectorSearch Constructor", () => {
287
298
} ) ;
288
299
289
300
test ( "initializes AsyncCaller with custom parameters" , ( ) => {
290
- const vectorStore = new MongoDBAtlasVectorSearch ( new OpenAIEmbeddings ( ) , {
301
+ const vectorStore = new MongoDBAtlasVectorSearch ( getEmbeddings ( ) , {
291
302
collection,
292
303
maxConcurrency : 5 ,
293
304
maxRetries : 3 ,
@@ -304,7 +315,7 @@ describe("MongoDBAtlasVectorSearch Constructor", () => {
304
315
} ) ;
305
316
306
317
describe ( "addVectors method" , ( ) => {
307
- let embeddings : OpenAIEmbeddings ;
318
+ let embeddings : EmbeddingsInterface ;
308
319
let vectorStore : PatchedVectorStore ;
309
320
let vectors : number [ ] [ ] ;
310
321
const documents = [
@@ -313,8 +324,8 @@ describe("addVectors method", () => {
313
324
] ;
314
325
315
326
beforeEach ( async ( ) => {
316
- embeddings = new OpenAIEmbeddings ( ) ;
317
- vectorStore = new PatchedVectorStore ( new OpenAIEmbeddings ( ) , {
327
+ embeddings = getEmbeddings ( ) ;
328
+ vectorStore = new PatchedVectorStore ( getEmbeddings ( ) , {
318
329
collection,
319
330
} ) ;
320
331
vectors = await embeddings . embedDocuments ( [ "test 1" , "test 2" ] ) ;
@@ -388,14 +399,14 @@ describe("addVectors method", () => {
388
399
} ) ;
389
400
390
401
describe ( "addDocuments method" , ( ) => {
391
- let embeddings : OpenAIEmbeddings ;
402
+ let embeddings : EmbeddingsInterface ;
392
403
let vectorStore : PatchedVectorStore ;
393
404
const documents = [
394
405
new Document ( { pageContent : "test 1" } ) ,
395
406
new Document ( { pageContent : "test 2" } ) ,
396
407
] ;
397
408
beforeEach ( async ( ) => {
398
- embeddings = new OpenAIEmbeddings ( ) ;
409
+ embeddings = getEmbeddings ( ) ;
399
410
vectorStore = new PatchedVectorStore ( embeddings , {
400
411
collection,
401
412
} ) ;
@@ -474,10 +485,10 @@ describe("addDocuments method", () => {
474
485
} ) ;
475
486
476
487
describe ( "similaritySearchVectorWithScore method" , ( ) => {
477
- let embeddings : OpenAIEmbeddings ;
488
+ let embeddings : EmbeddingsInterface ;
478
489
let vectorStore : PatchedVectorStore ;
479
490
beforeEach ( async ( ) => {
480
- embeddings = new OpenAIEmbeddings ( ) ;
491
+ embeddings = getEmbeddings ( ) ;
481
492
vectorStore = new PatchedVectorStore ( embeddings , {
482
493
collection,
483
494
} ) ;
@@ -717,7 +728,7 @@ describe("delete method", () => {
717
728
} ) ;
718
729
719
730
test ( "removes documents by ids" , async ( ) => {
720
- const vectorStore = new PatchedVectorStore ( new OpenAIEmbeddings ( ) , {
731
+ const vectorStore = new PatchedVectorStore ( getEmbeddings ( ) , {
721
732
collection,
722
733
} ) ;
723
734
@@ -767,7 +778,7 @@ describe("delete method", () => {
767
778
} ) ;
768
779
769
780
test ( "ignores non-existent ids" , async ( ) => {
770
- const vectorStore = new PatchedVectorStore ( new OpenAIEmbeddings ( ) , {
781
+ const vectorStore = new PatchedVectorStore ( getEmbeddings ( ) , {
771
782
collection,
772
783
} ) ;
773
784
@@ -783,10 +794,10 @@ describe("delete method", () => {
783
794
784
795
describe ( "Static Methods" , ( ) => {
785
796
describe ( "fromTexts" , ( ) => {
786
- let embeddings : OpenAIEmbeddings ;
797
+ let embeddings : EmbeddingsInterface ;
787
798
const texts = [ "text1" , "text2" , "text3" ] ;
788
799
beforeEach ( ( ) => {
789
- embeddings = new OpenAIEmbeddings ( ) ;
800
+ embeddings = getEmbeddings ( ) ;
790
801
} ) ;
791
802
792
803
test ( "populates a vector store from strings with a metadata object" , async ( ) => {
@@ -838,7 +849,7 @@ describe("Static Methods", () => {
838
849
] ;
839
850
const store = await MongoDBAtlasVectorSearch . fromDocuments (
840
851
documents ,
841
- new OpenAIEmbeddings ( ) ,
852
+ getEmbeddings ( ) ,
842
853
{ collection }
843
854
) ;
844
855
expect ( store ) . toBeInstanceOf ( MongoDBAtlasVectorSearch ) ;
@@ -850,11 +861,9 @@ describe("Static Methods", () => {
850
861
new Document ( { pageContent : "doc2" , metadata : { source : "source2" } } ) ,
851
862
] ;
852
863
853
- await MongoDBAtlasVectorSearch . fromDocuments (
854
- documents ,
855
- new OpenAIEmbeddings ( ) ,
856
- { collection }
857
- ) ;
864
+ await MongoDBAtlasVectorSearch . fromDocuments ( documents , getEmbeddings ( ) , {
865
+ collection,
866
+ } ) ;
858
867
859
868
const results = await collection
860
869
. find ( { } , { projection : { _id : 0 , text : 1 , embedding : 1 } } )
@@ -873,11 +882,10 @@ describe("Static Methods", () => {
873
882
new Document ( { pageContent : "doc2" , metadata : { source : "source2" } } ) ,
874
883
] ;
875
884
876
- await MongoDBAtlasVectorSearch . fromDocuments (
877
- documents ,
878
- new OpenAIEmbeddings ( ) ,
879
- { collection, ids : [ "custom1" , "custom2" ] }
880
- ) ;
885
+ await MongoDBAtlasVectorSearch . fromDocuments ( documents , getEmbeddings ( ) , {
886
+ collection,
887
+ ids : [ "custom1" , "custom2" ] ,
888
+ } ) ;
881
889
882
890
const results = await collection
883
891
. find ( { } , { projection : { _id : 1 , text : 1 } } )
@@ -896,7 +904,7 @@ describe("Static Methods", () => {
896
904
new Document ( { pageContent : "doc1" } ) ,
897
905
new Document ( { pageContent : "doc2" } ) ,
898
906
] ,
899
- new OpenAIEmbeddings ( ) ,
907
+ getEmbeddings ( ) ,
900
908
{ collection, ids : [ "id1" , "id2" ] }
901
909
) ;
902
910
@@ -905,7 +913,7 @@ describe("Static Methods", () => {
905
913
new Document ( { pageContent : "updated 1" } ) ,
906
914
new Document ( { pageContent : "updated 2" } ) ,
907
915
] ,
908
- new OpenAIEmbeddings ( ) ,
916
+ getEmbeddings ( ) ,
909
917
{ collection, ids : [ "id1" , "id2" ] }
910
918
) ;
911
919
0 commit comments