Skip to content

Commit dcf7d87

Browse files
committed
chore: add more tests to ensure embedding generation works
1 parent 89a556d commit dcf7d87

File tree

3 files changed

+147
-4
lines changed

3 files changed

+147
-4
lines changed

src/common/search/embeddingsProvider.ts

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,10 @@ export type EmbeddingParameters = {
1212
inputType: "query" | "document";
1313
};
1414

15-
interface EmbeddingsProvider<SupportedModels extends string, SupportedEmbeddingParameters extends EmbeddingParameters> {
15+
export interface EmbeddingsProvider<
16+
SupportedModels extends string,
17+
SupportedEmbeddingParameters extends EmbeddingParameters,
18+
> {
1619
embed(
1720
modelId: SupportedModels,
1821
content: EmbeddingsInput[],

src/common/search/vectorSearchEmbeddingsManager.ts

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ export class VectorSearchEmbeddingsManager {
3535
constructor(
3636
private readonly config: UserConfig,
3737
private readonly connectionManager: ConnectionManager,
38-
private readonly embeddings: Map<EmbeddingNamespace, VectorFieldIndexDefinition[]> = new Map()
38+
private readonly embeddings: Map<EmbeddingNamespace, VectorFieldIndexDefinition[]> = new Map(),
39+
private readonly embeddingsProvider: typeof getEmbeddingsProvider = getEmbeddingsProvider
3940
) {
4041
connectionManager.events.on("connection-close", () => {
4142
this.embeddings.clear();
@@ -242,15 +243,21 @@ export class VectorSearchEmbeddingsManager {
242243
);
243244
}
244245

245-
const embeddingsProvider = getEmbeddingsProvider(this.config);
246+
const embeddingsProvider = this.embeddingsProvider(this.config);
246247

247248
if (!embeddingsProvider) {
248249
throw new MongoDBError(ErrorCodes.NoEmbeddingsProviderConfigured, "No embeddings provider configured.");
249250
}
250251

252+
if (this.config.disableEmbeddingsValidation) {
253+
return await embeddingsProvider.embed(embeddingParameters.model, rawValues, {
254+
inputType,
255+
...embeddingParameters,
256+
});
257+
}
258+
251259
const embeddingInfoForCollection = await this.embeddingsForNamespace({ database, collection });
252260
const embeddingInfoForPath = embeddingInfoForCollection.find((definition) => definition.path === path);
253-
254261
if (!embeddingInfoForPath) {
255262
throw new MongoDBError(
256263
ErrorCodes.AtlasVectorSearchIndexNotFound,

tests/unit/common/search/vectorSearchEmbeddingsManager.test.ts

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,13 @@ import { ConnectionStateConnected } from "../../../../src/common/connectionManag
1313
import type { InsertOneResult } from "mongodb";
1414
import type { DropDatabaseResult } from "@mongosh/service-provider-node-driver/lib/node-driver-service-provider.js";
1515
import EventEmitter from "events";
16+
import {
17+
zVoyageEmbeddingParameters,
18+
type EmbeddingParameters,
19+
type EmbeddingsProvider,
20+
type getEmbeddingsProvider,
21+
} from "../../../../src/common/search/embeddingsProvider.js";
22+
import { ErrorCodes, MongoDBError } from "../../../../src/common/errors.js";
1623

1724
type MockedServiceProvider = NodeDriverServiceProvider & {
1825
getSearchIndexes: MockedFunction<NodeDriverServiceProvider["getSearchIndexes"]>;
@@ -25,6 +32,10 @@ type MockedConnectionManager = ConnectionManager & {
2532
currentConnectionState: ConnectionStateConnected;
2633
};
2734

35+
// An EmbeddingsProvider whose `embed` member is typed as a mock function,
// so tests can stub its resolved value and reset it between cases.
type MockedEmbeddingsProvider = EmbeddingsProvider<string, EmbeddingParameters> & {
36+
embed: MockedFunction<EmbeddingsProvider<string, EmbeddingParameters>["embed"]>;
37+
};
38+
2839
const database = "my" as const;
2940
const collection = "collection" as const;
3041
const mapKey = `${database}.${collection}` as EmbeddingNamespace;
@@ -78,13 +89,22 @@ describe("VectorSearchEmbeddingsManager", () => {
7889
getURI: () => "mongodb://my-test",
7990
} as unknown as MockedServiceProvider;
8091

92+
// Mock provider for the suite; `embed` starts as a bare vi.fn() with no stubbed result.
const embeddingsProvider: MockedEmbeddingsProvider = {
93+
embed: vi.fn(),
94+
};
95+
96+
// Stand-in typed as `typeof getEmbeddingsProvider`: it takes no arguments and
// always returns the `embeddingsProvider` mock.
const getMockedEmbeddingsProvider: typeof getEmbeddingsProvider = () => {
97+
return embeddingsProvider;
98+
};
99+
81100
const connectionManager: MockedConnectionManager = {
82101
currentConnectionState: new ConnectionStateConnected(provider),
83102
events: eventEmitter,
84103
} as unknown as MockedConnectionManager;
85104

86105
beforeEach(() => {
87106
provider.getSearchIndexes.mockReset();
107+
embeddingsProvider.embed.mockReset();
88108

89109
provider.createSearchIndexes.mockResolvedValue([]);
90110
provider.insertOne.mockResolvedValue({} as unknown as InsertOneResult);
@@ -371,4 +391,117 @@ describe("VectorSearchEmbeddingsManager", () => {
371391
});
372392
});
373393
});
394+
395+
// Exercises VectorSearchEmbeddingsManager.generateEmbeddings with the mocked
// embeddings provider injected through the constructor.
describe("generate embeddings", () => {
396+
// Request passed to generateEmbeddings by every case in this suite.
const embeddingToGenerate = {
397+
database: "mydb",
398+
collection: "mycoll",
399+
path: "embedding_field",
400+
rawValues: ["oops"],
401+
embeddingParameters: { model: "voyage-3-large", outputDimension: 1024, outputDType: "float" } as const,
402+
inputType: "query" as const,
403+
};
404+
405+
let embeddings: VectorSearchEmbeddingsManager;
406+
407+
// Default manager uses the validation-disabled config; each nested suite
// below replaces it with the config that suite needs.
beforeEach(() => {
408+
embeddings = new VectorSearchEmbeddingsManager(
409+
embeddingValidationDisabled,
410+
connectionManager,
411+
new Map(),
412+
getMockedEmbeddingsProvider
413+
);
414+
});
415+
416+
describe("when atlas search is not available", () => {
417+
beforeEach(() => {
418+
embeddings = new VectorSearchEmbeddingsManager(
419+
embeddingValidationEnabled,
420+
connectionManager,
421+
new Map(),
422+
getMockedEmbeddingsProvider
423+
);
424+
425+
// The search-index lookup rejects; this is how the suite models the
// "not available" condition named in its title.
provider.getSearchIndexes.mockRejectedValue(new Error());
426+
});
427+
428+
it("throws an exception", async () => {
429+
await expect(embeddings.generateEmbeddings(embeddingToGenerate)).rejects.toThrowError();
430+
});
431+
});
432+
433+
describe("when atlas search is available", () => {
434+
describe("when embedding validation is disabled", () => {
435+
beforeEach(() => {
436+
embeddings = new VectorSearchEmbeddingsManager(
437+
embeddingValidationDisabled,
438+
connectionManager,
439+
new Map(),
440+
getMockedEmbeddingsProvider
441+
);
442+
});
443+
444+
describe("when no index is available for path", () => {
445+
it("returns the embeddings as is", async () => {
446+
// 0xc0ffee is an arbitrary sentinel; the test only checks it comes back unchanged.
embeddingsProvider.embed.mockResolvedValue([[0xc0ffee]]);
447+
448+
const [result] = await embeddings.generateEmbeddings(embeddingToGenerate);
449+
expect(result).toEqual([0xc0ffee]);
450+
});
451+
});
452+
});
453+
454+
describe("when embedding validation is enabled", () => {
455+
beforeEach(() => {
456+
embeddings = new VectorSearchEmbeddingsManager(
457+
embeddingValidationEnabled,
458+
connectionManager,
459+
new Map(),
460+
getMockedEmbeddingsProvider
461+
);
462+
});
463+
464+
describe("when no index is available for path", () => {
465+
it("throws an exception", async () => {
466+
await expect(embeddings.generateEmbeddings(embeddingToGenerate)).rejects.toThrowError();
467+
});
468+
});
469+
470+
describe("when index is available on path", () => {
471+
beforeEach(() => {
472+
// One vectorSearch index whose vector field sits on the requested path.
provider.getSearchIndexes.mockResolvedValue([
473+
{
474+
id: "65e8c766d0450e3e7ab9855f",
475+
name: "vector-search-test",
476+
type: "vectorSearch",
477+
status: "READY",
478+
queryable: true,
479+
latestDefinition: {
480+
fields: [
481+
{
482+
type: "vector",
483+
path: embeddingToGenerate.path,
484+
numDimensions: 1024,
485+
similarity: "euclidean",
486+
},
487+
{ type: "filter", path: "genres" },
488+
{ type: "filter", path: "year" },
489+
],
490+
},
491+
},
492+
]);
493+
});
494+
495+
// The manager in this branch is built with `embeddingValidationEnabled`
// (enclosing beforeEach), so the title describes the provider outcome
// rather than a validation setting.
describe("when the provider returns embeddings", () => {
496+
it("returns the embeddings as is", async () => {
497+
embeddingsProvider.embed.mockResolvedValue([[0xc0ffee]]);
498+
499+
const [result] = await embeddings.generateEmbeddings(embeddingToGenerate);
500+
expect(result).toEqual([0xc0ffee]);
501+
});
502+
});
503+
});
504+
});
505+
});
506+
});
374507
});

0 commit comments

Comments
 (0)