diff --git a/src/tools/mongodb/create/createVectorIndex.ts b/src/tools/mongodb/create/createVectorIndex.ts new file mode 100644 index 00000000..55ac2c8d --- /dev/null +++ b/src/tools/mongodb/create/createVectorIndex.ts @@ -0,0 +1,44 @@ +import { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; +import { buildVectorFields, DbOperationArgs, MongoDBToolBase, VectorIndexArgs } from "../mongodbTool.js"; +import { OperationType, ToolArgs } from "../../tool.js"; + +const VECTOR_INDEX_TYPE = "vectorSearch"; +export class CreateVectorIndexTool extends MongoDBToolBase { + protected name = "create-vector-index"; + protected description = "Create an Atlas Vector Search Index for a collection."; + protected argsShape = { + ...DbOperationArgs, + name: VectorIndexArgs.name, + vectorDefinition: VectorIndexArgs.vectorDefinition, + filterFields: VectorIndexArgs.filterFields, + }; + + protected operationType: OperationType = "create"; + + protected async execute({ + database, + collection, + name, + vectorDefinition, + filterFields, + }: ToolArgs): Promise { + const provider = await this.ensureConnected(); + + const indexes = await provider.createSearchIndexes(database, collection, [ + { + name, + type: VECTOR_INDEX_TYPE, + definition: { fields: buildVectorFields(vectorDefinition, filterFields) }, + }, + ]); + + return { + content: [ + { + text: `Created the vector index ${indexes[0]} on collection "${collection}" in database "${database}"`, + type: "text", + }, + ], + }; + } +} diff --git a/src/tools/mongodb/mongodbTool.ts b/src/tools/mongodb/mongodbTool.ts index 5d387b25..db4e0e2a 100644 --- a/src/tools/mongodb/mongodbTool.ts +++ b/src/tools/mongodb/mongodbTool.ts @@ -1,5 +1,5 @@ import { z } from "zod"; -import { ToolArgs, ToolBase, ToolCategory, TelemetryToolMetadata } from "../tool.js"; +import { TelemetryToolMetadata, ToolArgs, ToolBase, ToolCategory } from "../tool.js"; import { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver"; import { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; import { ErrorCodes, MongoDBError } from "../../errors.js"; @@ -10,6 +10,64 @@ export const DbOperationArgs = { collection: z.string().describe("Collection name"), }; +export enum VectorFieldType { + VECTOR = "vector", + FILTER = "filter", +} +export const VectorIndexArgs = { + name: z.string().describe("The name of the index"), + vectorDefinition: z + .object({ + path: z + .string() + .min(1) + .describe( + "Name of the field to index. For nested fields, use dot notation to specify path to embedded fields." + ), + numDimensions: z + .number() + .int() + .min(1) + .max(8192) + .describe("Number of vector dimensions to enforce at index-time and query-time."), + similarity: z + .enum(["euclidean", "cosine", "dotProduct"]) + .describe("Vector similarity function to use to search for top K-nearest neighbors."), + quantization: z + .enum(["none", "scalar", "binary"]) + .default("none") + .optional() + .describe( + "Automatic vector quantization. Use this setting only if your embeddings are float or double vectors." + ), + }) + .describe("The vector index definition."), + filterFields: z + .array( + z.object({ + path: z + .string() + .min(1) + .describe( + "Name of the field to filter by. For nested fields, use dot notation to specify path to embedded fields." + ), + }) + ) + .optional() + .describe("Additional indexed fields that pre-filter data."), +}; + +type VectorDefinitionType = z.infer; +type FilterFieldsType = z.infer; +export function buildVectorFields(vectorDefinition: VectorDefinitionType, filterFields: FilterFieldsType): object[] { + const typedVectorField = { ...vectorDefinition, type: VectorFieldType.VECTOR }; + const typedFilterFields = (filterFields ?? []).map((f) => ({ + ...f, + type: VectorFieldType.FILTER, + })); + return [typedVectorField, ...typedFilterFields]; +} + export const SearchIndexOperationArgs = { database: z.string().describe("Database name"), collection: z.string().describe("Collection name"), diff --git a/src/tools/mongodb/tools.ts b/src/tools/mongodb/tools.ts index 34074a07..ac79121c 100644 --- a/src/tools/mongodb/tools.ts +++ b/src/tools/mongodb/tools.ts @@ -18,6 +18,8 @@ import { DropCollectionTool } from "./delete/dropCollection.js"; import { ExplainTool } from "./metadata/explain.js"; import { CreateCollectionTool } from "./create/createCollection.js"; import { LogsTool } from "./metadata/logs.js"; +import { CreateVectorIndexTool } from "./create/createVectorIndex.js"; +import { UpdateVectorIndexTool } from "./update/updateVectorIndex.js"; import { CollectionSearchIndexesTool } from "./read/collectionSearchIndexes.js"; import { DropSearchIndexTool } from "./delete/dropSearchIndex.js"; @@ -43,5 +45,7 @@ export const MongoDbTools = [ ExplainTool, CreateCollectionTool, LogsTool, + CreateVectorIndexTool, + UpdateVectorIndexTool, DropSearchIndexTool, ]; diff --git a/src/tools/mongodb/update/updateVectorIndex.ts b/src/tools/mongodb/update/updateVectorIndex.ts new file mode 100644 index 00000000..476cfc5e --- /dev/null +++ b/src/tools/mongodb/update/updateVectorIndex.ts @@ -0,0 +1,41 @@ +import { CallToolResult } from "@modelcontextprotocol/sdk/types.js"; +import { buildVectorFields, DbOperationArgs, MongoDBToolBase, VectorIndexArgs } from "../mongodbTool.js"; +import { OperationType, ToolArgs } from "../../tool.js"; + +export class UpdateVectorIndexTool extends MongoDBToolBase { + protected name = "update-vector-index"; + protected description = "Updates an Atlas Search vector for a collection"; + protected argsShape = { + ...DbOperationArgs, + name: VectorIndexArgs.name, + vectorDefinition: VectorIndexArgs.vectorDefinition, + filterFields: VectorIndexArgs.filterFields, + }; + + protected operationType: OperationType = "create"; + + protected async execute({ + database, + collection, + name, + vectorDefinition, + filterFields, + }: ToolArgs): Promise { + const provider = await this.ensureConnected(); + + // @ts-expect-error: Interface expects a SearchIndexDefinition {definition: {fields}}. However, + // passing fields at the root level is necessary for the call to succeed. + await provider.updateSearchIndex(database, collection, name, { + fields: buildVectorFields(vectorDefinition, filterFields), + }); + + return { + content: [ + { + text: `Successfully updated vector index "${name}" on collection "${collection}" in database "${database}"`, + type: "text", + }, + ], + }; + } +}