@@ -13,6 +13,13 @@ import { ConnectionStateConnected } from "../../../../src/common/connectionManag
1313import type { InsertOneResult } from "mongodb" ;
1414import type { DropDatabaseResult } from "@mongosh/service-provider-node-driver/lib/node-driver-service-provider.js" ;
1515import EventEmitter from "events" ;
16+ import {
17+ zVoyageEmbeddingParameters ,
18+ type EmbeddingParameters ,
19+ type EmbeddingsProvider ,
20+ type getEmbeddingsProvider ,
21+ } from "../../../../src/common/search/embeddingsProvider.js" ;
22+ import { ErrorCodes , MongoDBError } from "../../../../src/common/errors.js" ;
1623
1724type MockedServiceProvider = NodeDriverServiceProvider & {
1825 getSearchIndexes : MockedFunction < NodeDriverServiceProvider [ "getSearchIndexes" ] > ;
@@ -25,6 +32,10 @@ type MockedConnectionManager = ConnectionManager & {
2532 currentConnectionState : ConnectionStateConnected ;
2633} ;
2734
35+ type MockedEmbeddingsProvider = EmbeddingsProvider < string , EmbeddingParameters > & {
36+ embed : MockedFunction < EmbeddingsProvider < string , EmbeddingParameters > [ "embed" ] > ;
37+ } ;
38+
2839const database = "my" as const ;
2940const collection = "collection" as const ;
3041const mapKey = `${ database } .${ collection } ` as EmbeddingNamespace ;
@@ -78,13 +89,22 @@ describe("VectorSearchEmbeddingsManager", () => {
7889 getURI : ( ) => "mongodb://my-test" ,
7990 } as unknown as MockedServiceProvider ;
8091
92+ const embeddingsProvider : MockedEmbeddingsProvider = {
93+ embed : vi . fn ( ) ,
94+ } ;
95+
96+ const getMockedEmbeddingsProvider : typeof getEmbeddingsProvider = ( ) => {
97+ return embeddingsProvider ;
98+ } ;
99+
81100 const connectionManager : MockedConnectionManager = {
82101 currentConnectionState : new ConnectionStateConnected ( provider ) ,
83102 events : eventEmitter ,
84103 } as unknown as MockedConnectionManager ;
85104
86105 beforeEach ( ( ) => {
87106 provider . getSearchIndexes . mockReset ( ) ;
107+ embeddingsProvider . embed . mockReset ( ) ;
88108
89109 provider . createSearchIndexes . mockResolvedValue ( [ ] ) ;
90110 provider . insertOne . mockResolvedValue ( { } as unknown as InsertOneResult ) ;
@@ -371,4 +391,117 @@ describe("VectorSearchEmbeddingsManager", () => {
371391 } ) ;
372392 } ) ;
373393 } ) ;
394+
395+ describe ( "generate embeddings" , ( ) => {
396+ const embeddingToGenerate = {
397+ database : "mydb" ,
398+ collection : "mycoll" ,
399+ path : "embedding_field" ,
400+ rawValues : [ "oops" ] ,
401+ embeddingParameters : { model : "voyage-3-large" , outputDimension : 1024 , outputDType : "float" } as const ,
402+ inputType : "query" as const ,
403+ } ;
404+
405+ let embeddings : VectorSearchEmbeddingsManager ;
406+
407+ beforeEach ( ( ) => {
408+ embeddings = new VectorSearchEmbeddingsManager (
409+ embeddingValidationDisabled ,
410+ connectionManager ,
411+ new Map ( ) ,
412+ getMockedEmbeddingsProvider
413+ ) ;
414+ } ) ;
415+
416+ describe ( "when atlas search is not available" , ( ) => {
417+ beforeEach ( ( ) => {
418+ embeddings = new VectorSearchEmbeddingsManager (
419+ embeddingValidationEnabled ,
420+ connectionManager ,
421+ new Map ( ) ,
422+ getMockedEmbeddingsProvider
423+ ) ;
424+
425+ provider . getSearchIndexes . mockRejectedValue ( new Error ( ) ) ;
426+ } ) ;
427+
428+ it ( "throws an exception" , async ( ) => {
429+ await expect ( embeddings . generateEmbeddings ( embeddingToGenerate ) ) . rejects . toThrowError ( ) ;
430+ } ) ;
431+ } ) ;
432+
433+ describe ( "when atlas search is available" , ( ) => {
434+ describe ( "when embedding validation is disabled" , ( ) => {
435+ beforeEach ( ( ) => {
436+ embeddings = new VectorSearchEmbeddingsManager (
437+ embeddingValidationDisabled ,
438+ connectionManager ,
439+ new Map ( ) ,
440+ getMockedEmbeddingsProvider
441+ ) ;
442+ } ) ;
443+
444+ describe ( "when no index is available for path" , ( ) => {
445+ it ( "returns the embeddings as is" , async ( ) => {
446+ embeddingsProvider . embed . mockResolvedValue ( [ [ 0xc0ffee ] ] ) ;
447+
448+ const [ result ] = await embeddings . generateEmbeddings ( embeddingToGenerate ) ;
449+ expect ( result ) . toEqual ( [ 0xc0ffee ] ) ;
450+ } ) ;
451+ } ) ;
452+ } ) ;
453+
454+ describe ( "when embedding validation is enabled" , ( ) => {
455+ beforeEach ( ( ) => {
456+ embeddings = new VectorSearchEmbeddingsManager (
457+ embeddingValidationEnabled ,
458+ connectionManager ,
459+ new Map ( ) ,
460+ getMockedEmbeddingsProvider
461+ ) ;
462+ } ) ;
463+
464+ describe ( "when no index is available for path" , ( ) => {
465+ it ( "throws an exception" , async ( ) => {
466+ await expect ( embeddings . generateEmbeddings ( embeddingToGenerate ) ) . rejects . toThrowError ( ) ;
467+ } ) ;
468+ } ) ;
469+
470+ describe ( "when index is available on path" , ( ) => {
471+ beforeEach ( ( ) => {
472+ provider . getSearchIndexes . mockResolvedValue ( [
473+ {
474+ id : "65e8c766d0450e3e7ab9855f" ,
475+ name : "vector-search-test" ,
476+ type : "vectorSearch" ,
477+ status : "READY" ,
478+ queryable : true ,
479+ latestDefinition : {
480+ fields : [
481+ {
482+ type : "vector" ,
483+ path : embeddingToGenerate . path ,
484+ numDimensions : 1024 ,
485+ similarity : "euclidean" ,
486+ } ,
487+ { type : "filter" , path : "genres" } ,
488+ { type : "filter" , path : "year" } ,
489+ ] ,
490+ } ,
491+ } ,
492+ ] ) ;
493+ } ) ;
494+
495+ describe ( "when embedding validation is disabled" , ( ) => {
496+ it ( "returns the embeddings as is" , async ( ) => {
497+ embeddingsProvider . embed . mockResolvedValue ( [ [ 0xc0ffee ] ] ) ;
498+
499+ const [ result ] = await embeddings . generateEmbeddings ( embeddingToGenerate ) ;
500+ expect ( result ) . toEqual ( [ 0xc0ffee ] ) ;
501+ } ) ;
502+ } ) ;
503+ } ) ;
504+ } ) ;
505+ } ) ;
506+ } ) ;
374507} ) ;
0 commit comments