Skip to content

Commit a098234

Browse files
committed
chore: draft integration of embeddings with the aggregate tool
1 parent c7d8e59 commit a098234

File tree

4 files changed

+240
-41
lines changed

4 files changed

+240
-41
lines changed

src/common/errors.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ export enum ErrorCodes {
44
ForbiddenCollscan = 1_000_002,
55
ForbiddenWriteOperation = 1_000_003,
66
AtlasSearchNotSupported = 1_000_004,
7+
NoEmbeddingsProviderConfigured = 1_000_005,
8+
AtlasVectorSearchIndexNotFound = 1_000_006,
9+
AtlasVectorSearchInvalidQuery = 1_000_007,
710
}
811

912
export class MongoDBError<ErrorCode extends ErrorCodes = ErrorCodes> extends Error {

src/common/search/embeddingsProvider.ts

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -10,40 +10,17 @@ const zEmbeddingsInput = z.string();
1010
/** A single raw input to embed; inferred from the `zEmbeddingsInput` schema (`z.string()`). */
type EmbeddingsInput = z.infer<typeof zEmbeddingsInput>;
1111
/** One embedding vector, represented as a plain array of numbers. */
type Embeddings = number[];
1212

13+
/**
 * Provider-agnostic options passed to `EmbeddingsProvider.embed`.
 */
type EmbeddingParameters = {
    /** Number of dimensions requested for each generated vector. */
    numDimensions: number;
    /** Quantization requested for the vectors. NOTE(review): the set of valid values is not constrained here. */
    quantization: string;
    /** Whether the content being embedded is a search query or a document. */
    inputType: "query" | "document";
};
18+
1319
/**
 * Contract implemented by embedding backends, parameterised by the model
 * identifiers the backend accepts.
 */
interface EmbeddingsProvider<SupportedModels extends string> {
    /**
     * Generates embeddings for a batch of inputs.
     *
     * @param modelId - Identifier of the model to use.
     * @param content - Raw inputs to embed.
     * @param parameters - Provider-agnostic embedding options.
     * @returns A promise resolving to the generated embedding vectors.
     */
    embed(modelId: SupportedModels, content: EmbeddingsInput[], parameters: EmbeddingParameters): Promise<Embeddings[]>;
}
1622

17-
const zVoyageSupportedDimensions = z
18-
.union([z.literal(256), z.literal(512), z.literal(1024), z.literal(2048)])
19-
.default(1024);
20-
21-
const zVoyageQuantization = z.enum(["float", "int8", "binary", "ubinary"]).default("float");
22-
const zVoyageInputType = z.enum(["query", "document"]);
23-
2423
/** Zod enum of the Voyage model identifiers accepted by this module. */
export const zVoyageModels = z.enum(["voyage-3-large", "voyage-3.5", "voyage-3.5-lite", "voyage-code-3"]);
25-
export const zVoyageParameters = {
26-
"voyage-3-large": z.object({
27-
inputType: zVoyageInputType,
28-
outputDimensions: zVoyageSupportedDimensions,
29-
outputDtype: zVoyageQuantization,
30-
}),
31-
"voyage-3.5": z.object({
32-
inputType: zVoyageInputType,
33-
outputDimensions: zVoyageSupportedDimensions,
34-
outputDtype: zVoyageQuantization,
35-
}),
36-
"voyage-3.5-lite": z.object({
37-
inputType: zVoyageInputType,
38-
outputDimensions: zVoyageSupportedDimensions,
39-
outputDtype: zVoyageQuantization,
40-
}),
41-
"voyage-code-3": z.object({
42-
inputType: zVoyageInputType,
43-
outputDimensions: zVoyageSupportedDimensions,
44-
outputDtype: zVoyageQuantization,
45-
}),
46-
} as const;
4724

4825
type VoyageModels = z.infer<typeof zVoyageModels>;
4926
class VoyageEmbeddingsProvider implements EmbeddingsProvider<VoyageModels> {
@@ -68,12 +45,32 @@ class VoyageEmbeddingsProvider implements EmbeddingsProvider<VoyageModels> {
6845
/**
 * Generates embeddings for `content` with the given Voyage model.
 *
 * @param modelId - Voyage model used to build the text embedding model.
 * @param content - Raw inputs to embed; an empty list short-circuits without calling the provider.
 * @param parameters - Provider-agnostic options. `parameters.quantization` is intentionally
 *   not forwarded: the output dtype is pinned to `"float"`.
 * @returns The embeddings produced by `embedMany` for `content`.
 */
async embed<Model extends VoyageModels>(
    modelId: Model,
    content: EmbeddingsInput[],
    parameters: EmbeddingParameters
): Promise<Embeddings[]> {
    // Nothing to embed: avoid a pointless provider round trip.
    if (content.length === 0) {
        return [];
    }

    const voyageParameters = {
        inputType: parameters.inputType,
        outputDimensions: parameters.numDimensions,
        // Hardcoded on purpose as we don't do quantization yet. The key must be `outputDtype`
        // (the option name the previous Voyage parameter schema used), not `quantization`.
        outputDtype: "float",
    };

    const model = this.voyage.textEmbeddingModel(modelId);
    const { embeddings } = await embedMany({
        model,
        values: content,
        providerOptions: { voyage: voyageParameters },
    });

    return embeddings;
}
65+
66+
/**
 * Convenience wrapper around `embed` for a single input.
 *
 * @param modelId - Voyage model to use.
 * @param content - The single raw input to embed.
 * @param parameters - Provider-agnostic embedding options, forwarded to `embed`.
 * @returns The first embedding returned by `embed`, or an empty array when none was returned.
 */
async embedOne<Model extends VoyageModels>(
    modelId: Model,
    content: EmbeddingsInput,
    parameters: EmbeddingParameters
): Promise<Embeddings> {
    const [firstEmbedding] = await this.embed(modelId, [content], parameters);
    return firstEmbedding ?? [];
}
7774
}
7875

7976
export function getEmbeddingsProvider(userConfig: UserConfig): EmbeddingsProvider<VoyageModels> | undefined {
@@ -83,3 +80,6 @@ export function getEmbeddingsProvider(userConfig: UserConfig): EmbeddingsProvide
8380

8481
return undefined;
8582
}
83+
84+
/** Schema of every embedding model identifier supported; currently an alias of `zVoyageModels`. */
export const zSupportedEmbeddingModels = zVoyageModels;
85+
/** Union of the model identifiers accepted by `zSupportedEmbeddingModels`. */
export type SupportedEmbeddingModels = z.infer<typeof zSupportedEmbeddingModels>;

src/common/search/vectorSearchEmbeddingsManager.ts

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ import { BSON, type Document } from "bson";
33
import type { UserConfig } from "../config.js";
44
import type { ConnectionManager } from "../connectionManager.js";
55
import z from "zod";
6+
import { ErrorCodes, MongoDBError } from "../errors.js";
7+
import { getEmbeddingsProvider } from "./embeddingsProvider.js";
8+
import type { SupportedEmbeddingModels } from "./embeddingsProvider.js";
69

710
/** Zod enum of the accepted similarity names: `cosine`, `euclidean` and `dotProduct`. */
export const similarityEnum = z.enum(["cosine", "euclidean", "dotProduct"]);
811
/** Union of the string literals accepted by `similarityEnum`. */
export type Similarity = z.infer<typeof similarityEnum>;
@@ -216,6 +219,94 @@ export class VectorSearchEmbeddingsManager {
216219
return undefined;
217220
}
218221

222+
/**
 * Generates embeddings for `rawValues` and, unless validation is disabled, adapts them to the
 * storage format observed in one sampled document of the target collection.
 *
 * @param database - Database of the target namespace.
 * @param collection - Collection of the target namespace.
 * @param path - Field path that must be covered by a vector search index definition.
 * @param model - Embedding model used to generate the vectors.
 * @param rawValues - Raw strings to embed.
 * @param inputType - Whether the values are search queries or documents.
 * @returns One entry per generated embedding: either the plain number array returned by the
 *   provider or a `BSON.Binary` vector when the sampled document stores its vector that way.
 * @throws MongoDBError with `AtlasSearchNotSupported` when no search-capable provider is available.
 * @throws MongoDBError with `NoEmbeddingsProviderConfigured` when no embeddings provider is configured.
 * @throws MongoDBError with `AtlasVectorSearchIndexNotFound` when no index definition matches `path`.
 */
public async generateEmbeddings({
    database,
    collection,
    path,
    model,
    rawValues,
    inputType,
}: {
    database: string;
    collection: string;
    path: string;
    model: SupportedEmbeddingModels;
    rawValues: string[];
    inputType: "query" | "document";
}): Promise<unknown[]> {
    const provider = await this.assertAtlasSearchIsAvailable();
    if (!provider) {
        throw new MongoDBError(
            ErrorCodes.AtlasSearchNotSupported,
            "Atlas Search is not supported in this cluster."
        );
    }

    const embeddingsProvider = getEmbeddingsProvider(this.config);
    if (!embeddingsProvider) {
        throw new MongoDBError(ErrorCodes.NoEmbeddingsProviderConfigured, "No embeddings provider configured.");
    }

    const embeddingInfoForCollection = await this.embeddingsForNamespace({ database, collection });
    const embeddingInfoForPath = embeddingInfoForCollection.find((definition) => definition.path === path);
    if (!embeddingInfoForPath) {
        throw new MongoDBError(
            ErrorCodes.AtlasVectorSearchIndexNotFound,
            `No Vector Search index found for path "${path}" in namespace "${database}.${collection}"`
        );
    }

    const providerEmbeddings = await embeddingsProvider.embed(model, rawValues, {
        inputType,
        numDimensions: embeddingInfoForPath.numDimensions,
        quantization: embeddingInfoForPath.quantization,
    });

    if (this.config.disableEmbeddingsValidation) {
        return providerEmbeddings;
    }

    const hasDocuments = await provider.estimatedDocumentCount(database, collection);
    if (!hasDocuments) {
        return providerEmbeddings;
    }

    // `$${path}` projects the VALUE stored at `path`; a bare string in $project would be
    // emitted as a literal and the sample would never look like a vector.
    const oneDocument: Document | null = await provider
        .aggregate(database, collection, [{ $sample: { size: 1 } }, { $project: { embeddings: `$${path}` } }])
        .next();

    if (!oneDocument) {
        return providerEmbeddings;
    }

    const sampleEmbeddings: unknown = oneDocument.embeddings;
    if (!(sampleEmbeddings instanceof BSON.Binary)) {
        // Plain arrays (or a sample without the field) need no adaptation.
        return providerEmbeddings;
    }

    // The sample's format does not change between vectors, so pick the converter once.
    // NOTE(review): the probes are tried in this order and the first that does not throw wins;
    // if toBits() and toPackedBits() succeed on the same inputs the last branch is unreachable.
    let toStoredFormat: (vector: number[]) => unknown;
    if (this.matches(() => sampleEmbeddings.toBits())) {
        toStoredFormat = (vector) => BSON.Binary.fromBits(vector);
    } else if (this.matches(() => sampleEmbeddings.toInt8Array())) {
        toStoredFormat = (vector) => BSON.Binary.fromInt8Array(new Int8Array(vector));
    } else if (this.matches(() => sampleEmbeddings.toFloat32Array())) {
        toStoredFormat = (vector) => BSON.Binary.fromFloat32Array(new Float32Array(vector));
    } else if (this.matches(() => sampleEmbeddings.toPackedBits())) {
        toStoredFormat = (vector) => BSON.Binary.fromPackedBits(new Uint8Array(vector));
    } else {
        return providerEmbeddings;
    }

    return providerEmbeddings.map((vector) => toStoredFormat(vector));
}
309+
219310
private isANumber(value: unknown): boolean {
220311
if (typeof value === "number") {
221312
return true;
@@ -232,4 +323,13 @@ export class VectorSearchEmbeddingsManager {
232323

233324
return false;
234325
}
326+
327+
/**
 * Runs `fn` as a probe and reports whether it completed without throwing.
 * The value returned by `fn` and any error it throws are both discarded.
 *
 * @param fn - Probe to execute.
 * @returns `true` when `fn` returned normally, `false` when it threw.
 */
private matches(fn: () => unknown): boolean {
    let completed = true;
    try {
        fn();
    } catch {
        completed = false;
    }
    return completed;
}
235335
}

src/tools/mongodb/read/aggregate.ts

Lines changed: 105 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,53 @@ import { operationWithFallback } from "../../../helpers/operationWithFallback.js
1313
import { AGG_COUNT_MAX_TIME_MS_CAP, ONE_MB, CURSOR_LIMITS_TO_LLM_TEXT } from "../../../helpers/constants.js";
1414
import { zEJSON } from "../../args.js";
1515
import { LogId } from "../../../common/logger.js";
16+
import { SupportedEmbeddingModels, zSupportedEmbeddingModels } from "../../../common/search/embeddingsProvider.js";
17+
18+
// Schema for a pipeline stage with no dedicated shape; built by zEJSON() from ../../args.js.
// NOTE(review): presumably accepts any EJSON document — confirm against zEJSON's definition.
const AnyStage = zEJSON();
19+
/**
 * Schema for a `$vectorSearch` stage as accepted by the aggregate tool.
 *
 * On top of the stage's own fields it accepts `embeddingModel`, which is only needed when
 * `queryVector` is a raw string. It is therefore optional, so that a stage carrying an
 * already-computed numeric vector does not have to name a model. Unknown keys are passed through.
 */
const VectorSearchStage = z.object({
    $vectorSearch: z
        .object({
            exact: z
                .boolean()
                .optional()
                .default(false)
                .describe(
                    "When true, uses an ENN algorithm, otherwise uses ANN. Using ENN is not compatible with numCandidates, in that case, numCandidates must be left empty."
                ),
            index: z.string().optional().default("default"),
            path: z
                .string()
                .describe(
                    "Field, in dot notation, where to search. There must be a vector search index for that field. Note to LLM: When unsure, use the 'collection-indexes' tool to validate that the field is indexed with a vector search index."
                ),
            queryVector: z
                .union([z.string(), z.array(z.number())])
                .describe(
                    "The content to search for. If a string, the embeddingModel is mandatory, if not, the embeddingModel is ignored."
                ),
            numCandidates: z
                .number()
                .int()
                .positive()
                .optional()
                .describe("Number of candidates for the ANN algorithm. Only valid when exact is false."),
            limit: z.number().int().positive().optional().default(10),
            filter: zEJSON()
                .optional()
                .describe("MQL filter that can only use pre-filter fields from the index definition."),
            // Optional: only required when queryVector is a raw string (enforced when embeddings are generated).
            embeddingModel: zSupportedEmbeddingModels
                .optional()
                .describe(
                    "The embedding model to use to generate embeddings before search. Note to LLM: If unsure, ask the user before providing one."
                ),
        })
        .passthrough(),
});
1656

1757
export const AggregateArgs = {
18-
pipeline: z.array(zEJSON()).describe("An array of aggregation stages to execute"),
58+
pipeline: z
59+
.array(z.union([AnyStage, VectorSearchStage]))
60+
.describe(
61+
"An array of aggregation stages to execute. $vectorSearch can only appear as the first stage of the aggregation pipeline or as the first stage of a $unionWith subpipeline."
62+
),
1963
responseBytesLimit: z.number().optional().default(ONE_MB).describe(`\
2064
The maximum number of bytes to return in the response. This value is capped by the server’s configured maxBytesPerQuery and cannot be exceeded. \
2165
Note to LLM: If the entire aggregation result is required, use the "export" tool instead of increasing this limit.\
@@ -38,8 +82,7 @@ export class AggregateTool extends MongoDBToolBase {
3882
let aggregationCursor: AggregationCursor | undefined = undefined;
3983
try {
4084
const provider = await this.ensureConnected();
41-
42-
this.assertOnlyUsesPermittedStages(pipeline);
85+
await this.assertOnlyUsesPermittedStages(pipeline);
4386

4487
// Check if aggregate operation uses an index if enabled
4588
if (this.config.indexCheck) {
@@ -50,6 +93,12 @@ export class AggregateTool extends MongoDBToolBase {
5093
});
5194
}
5295

96+
pipeline = await this.replaceRawValuesWithEmbeddingsIfNecessary({
97+
database,
98+
collection,
99+
pipeline,
100+
});
101+
53102
const cappedResultsPipeline = [...pipeline];
54103
if (this.config.maxDocumentsPerQuery > 0) {
55104
cappedResultsPipeline.push({ $limit: this.config.maxDocumentsPerQuery });
@@ -107,8 +156,10 @@ export class AggregateTool extends MongoDBToolBase {
107156
}
108157
}
109158

110-
private assertOnlyUsesPermittedStages(pipeline: Record<string, unknown>[]): void {
159+
private async assertOnlyUsesPermittedStages(pipeline: Record<string, unknown>[]): Promise<void> {
111160
const writeOperations: OperationType[] = ["update", "create", "delete"];
161+
const isSearchSupported = await this.session.isSearchSupported();
162+
112163
let writeStageForbiddenError = "";
113164

114165
if (this.config.readOnly) {
@@ -118,14 +169,17 @@ export class AggregateTool extends MongoDBToolBase {
118169
"When 'create', 'update', or 'delete' operations are disabled, you can not run pipelines with $out or $merge stages.";
119170
}
120171

121-
if (!writeStageForbiddenError) {
122-
return;
123-
}
124-
125172
for (const stage of pipeline) {
126-
if (stage.$out || stage.$merge) {
173+
if ((stage.$out || stage.$merge) && writeStageForbiddenError) {
127174
throw new MongoDBError(ErrorCodes.ForbiddenWriteOperation, writeStageForbiddenError);
128175
}
176+
177+
if (stage.$vectorSearch && !isSearchSupported) {
178+
throw new MongoDBError(
179+
ErrorCodes.AtlasSearchNotSupported,
180+
"Atlas Search is not supported in this cluster."
181+
);
182+
}
129183
}
130184
}
131185

@@ -160,6 +214,48 @@ export class AggregateTool extends MongoDBToolBase {
160214
}, undefined);
161215
}
162216

217+
/**
 * Returns a copy of `pipeline` in which every top-level `$vectorSearch` stage is ready to be
 * sent to the server: the tool-only `embeddingModel` field is removed and a raw string
 * `queryVector` is replaced by its generated embedding.
 *
 * Only top-level stages are inspected; nested sub-pipelines are returned untouched.
 * The input pipeline and its stages are not mutated.
 *
 * @param database - Database the aggregation targets.
 * @param collection - Collection the aggregation targets.
 * @param pipeline - Aggregation stages as provided by the caller.
 * @returns The pipeline to execute.
 * @throws MongoDBError with `AtlasVectorSearchInvalidQuery` when `queryVector` is a string and
 *   no `embeddingModel` was provided. Errors from `generateEmbeddings` propagate unchanged.
 */
private async replaceRawValuesWithEmbeddingsIfNecessary({
    database,
    collection,
    pipeline,
}: {
    database: string;
    collection: string;
    pipeline: Document[];
}): Promise<Document[]> {
    const resolvedPipeline: Document[] = [];

    for (const stage of pipeline) {
        const vectorSearch: unknown = stage.$vectorSearch;
        if (typeof vectorSearch !== "object" || vectorSearch === null) {
            // Not a (well-formed) $vectorSearch stage: nothing to resolve.
            resolvedPipeline.push(stage);
            continue;
        }

        // `embeddingModel` is a tool-level field the server does not know about, so it is always
        // stripped — including when queryVector is already a vector.
        const { embeddingModel, ...serverVectorSearch } = vectorSearch as Document;

        if (typeof serverVectorSearch.queryVector !== "string") {
            // Already embeddings (number array or any other non-string value): pass through as is.
            resolvedPipeline.push({ ...stage, $vectorSearch: serverVectorSearch });
            continue;
        }

        if (!embeddingModel) {
            throw new MongoDBError(
                ErrorCodes.AtlasVectorSearchInvalidQuery,
                "embeddingModel is mandatory if queryVector is a raw string."
            );
        }

        const [embeddings] = await this.session.vectorSearchEmbeddingsManager.generateEmbeddings({
            database,
            collection,
            path: serverVectorSearch.path,
            model: embeddingModel as SupportedEmbeddingModels,
            // generateEmbeddings expects a batch, so wrap the single query string.
            rawValues: [serverVectorSearch.queryVector],
            inputType: "query",
        });

        resolvedPipeline.push({
            ...stage,
            $vectorSearch: { ...serverVectorSearch, queryVector: embeddings },
        });
    }

    return resolvedPipeline;
}
258+
163259
private generateMessage({
164260
aggResultsCount,
165261
documents,

0 commit comments

Comments
 (0)