Skip to content

Commit 8ac71ba

Browse files
committed
chore: Add new session-level service for getting embeddings of a specific collection
1 parent c12de89 commit 8ac71ba

File tree

9 files changed

+74
-2
lines changed

9 files changed

+74
-2
lines changed
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import type { NodeDriverServiceProvider } from "@mongosh/service-provider-node-driver";
2+
import type { Document } from "bson";
3+
4+
type VectorFieldIndexDefinition = {
5+
type: "vector";
6+
path: string;
7+
numDimensions: number;
8+
quantization: "none" | "scalar" | "binary";
9+
similarity: "euclidean" | "cosine" | "dotProduct";
10+
};
11+
12+
type EmbeddingNamespace = "${string}.${string}";
13+
export class VectorSearchEmbeddings {
14+
private embeddings: Map<EmbeddingNamespace, VectorFieldIndexDefinition[]>;
15+
16+
constructor() {
17+
this.embeddings = new Map();
18+
}
19+
20+
cleanupEmbeddingsForNamespace({ database, collection }: { database: string; collection: string }): void {
21+
const embeddingDefKey = `${database}.${collection}` as EmbeddingNamespace;
22+
this.embeddings.delete(embeddingDefKey);
23+
}
24+
25+
async embeddingsForNamespace({
26+
database,
27+
collection,
28+
provider,
29+
}: {
30+
database: string;
31+
collection: string;
32+
provider: NodeDriverServiceProvider;
33+
}): Promise<VectorFieldIndexDefinition[] | undefined> {
34+
const embeddingDefKey = `${database}.${collection}` as EmbeddingNamespace;
35+
const definition = this.embeddings.get(embeddingDefKey);
36+
37+
if (!definition) {
38+
const allSearchIndexes = await provider.getSearchIndexes(database, collection);
39+
const vectorSearchIndexes = allSearchIndexes.filter((index) => index.type === "vectorSearch");
40+
const vectorFields = vectorSearchIndexes
41+
// eslint-disable-next-line @typescript-eslint/no-unsafe-member-access
42+
.flatMap<Document>((index) => (index.latestDefinition?.fields as Document) ?? [])
43+
.filter((field) => this.isVectorFieldIndexDefinition(field));
44+
45+
this.embeddings.set(embeddingDefKey, vectorFields);
46+
return vectorFields;
47+
} else {
48+
return definition;
49+
}
50+
}
51+
52+
isVectorFieldIndexDefinition(doc: Document): doc is VectorFieldIndexDefinition {
53+
return doc["type"] === "vector";
54+
}
55+
}

src/common/session.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import type { NodeDriverServiceProvider } from "@mongosh/service-provider-node-d
1616
import { ErrorCodes, MongoDBError } from "./errors.js";
1717
import type { ExportsManager } from "./exportsManager.js";
1818
import type { Keychain } from "./keychain.js";
19+
import type { VectorSearchEmbeddings } from "./search/vectorSearchEmbeddings.js";
1920

2021
export interface SessionOptions {
2122
apiBaseUrl: string;
@@ -25,6 +26,7 @@ export interface SessionOptions {
2526
exportsManager: ExportsManager;
2627
connectionManager: ConnectionManager;
2728
keychain: Keychain;
29+
vectorSearchEmbeddings: VectorSearchEmbeddings;
2830
}
2931

3032
export type SessionEvents = {
@@ -40,6 +42,7 @@ export class Session extends EventEmitter<SessionEvents> {
4042
readonly connectionManager: ConnectionManager;
4143
readonly apiClient: ApiClient;
4244
readonly keychain: Keychain;
45+
readonly vectorSearchEmbeddings: VectorSearchEmbeddings;
4346

4447
mcpClient?: {
4548
name?: string;
@@ -57,6 +60,7 @@ export class Session extends EventEmitter<SessionEvents> {
5760
connectionManager,
5861
exportsManager,
5962
keychain,
63+
vectorSearchEmbeddings,
6064
}: SessionOptions) {
6165
super();
6266

@@ -73,6 +77,7 @@ export class Session extends EventEmitter<SessionEvents> {
7377
this.apiClient = new ApiClient({ baseUrl: apiBaseUrl, credentials }, logger);
7478
this.exportsManager = exportsManager;
7579
this.connectionManager = connectionManager;
80+
this.vectorSearchEmbeddings = vectorSearchEmbeddings;
7681
this.connectionManager.events.on("connection-success", () => this.emit("connect"));
7782
this.connectionManager.events.on("connection-time-out", (error) => this.emit("connection-error", error));
7883
this.connectionManager.events.on("connection-close", () => this.emit("disconnect"));

src/tools/mongodb/search/listSearchIndexes.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import { EJSON } from "bson";
66

77
export type SearchIndexStatus = {
88
name: string;
9-
type: string;
9+
type: "search" | "vectorSearch";
1010
status: string;
1111
queryable: boolean;
1212
latestDefinition: Document;
@@ -54,7 +54,7 @@ export class ListSearchIndexesTool extends MongoDBToolBase {
5454
protected pickRelevantInformation(indexes: Record<string, unknown>[]): SearchIndexStatus[] {
5555
return indexes.map((index) => ({
5656
name: (index["name"] ?? "default") as string,
57-
type: (index["type"] ?? "UNKNOWN") as string,
57+
type: (index["type"] ?? "UNKNOWN") as "search" | "vectorSearch",
5858
status: (index["status"] ?? "UNKNOWN") as string,
5959
queryable: (index["queryable"] ?? false) as boolean,
6060
latestDefinition: index["latestDefinition"] as Document,

src/transports/base.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import {
1616
} from "../common/connectionErrorHandler.js";
1717
import type { CommonProperties } from "../telemetry/types.js";
1818
import { Elicitation } from "../elicitation.js";
19+
import { VectorSearchEmbeddings } from "../common/search/vectorSearchEmbeddings.js";
1920

2021
export type TransportRunnerConfig = {
2122
userConfig: UserConfig;
@@ -89,6 +90,7 @@ export abstract class TransportRunnerBase {
8990
exportsManager,
9091
connectionManager,
9192
keychain: Keychain.root,
93+
vectorSearchEmbeddings: new VectorSearchEmbeddings(),
9294
});
9395

9496
const telemetry = Telemetry.create(session, this.userConfig, this.deviceId, {

tests/integration/helpers.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import { connectionErrorHandler } from "../../src/common/connectionErrorHandler.
2121
import { Keychain } from "../../src/common/keychain.js";
2222
import { Elicitation } from "../../src/elicitation.js";
2323
import type { MockClientCapabilities, createMockElicitInput } from "../utils/elicitationMocks.js";
24+
import { VectorSearchEmbeddings } from "../../src/common/search/vectorSearchEmbeddings.js";
2425

2526
export const driverOptions = setupDriverConfig({
2627
config,
@@ -101,6 +102,7 @@ export function setupIntegrationTest(
101102
exportsManager,
102103
connectionManager,
103104
keychain: new Keychain(),
105+
vectorSearchEmbeddings: new VectorSearchEmbeddings(),
104106
});
105107

106108
// Mock hasValidAccessToken for tests

tests/integration/telemetry.test.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import { CompositeLogger } from "../../src/common/logger.js";
88
import { MCPConnectionManager } from "../../src/common/connectionManager.js";
99
import { ExportsManager } from "../../src/common/exportsManager.js";
1010
import { Keychain } from "../../src/common/keychain.js";
11+
import { VectorSearchEmbeddings } from "../../src/common/search/vectorSearchEmbeddings.js";
1112

1213
describe("Telemetry", () => {
1314
it("should resolve the actual device ID", async () => {
@@ -23,6 +24,7 @@ describe("Telemetry", () => {
2324
exportsManager: ExportsManager.init(config, logger),
2425
connectionManager: new MCPConnectionManager(config, driverOptions, logger, deviceId),
2526
keychain: new Keychain(),
27+
vectorSearchEmbeddings: new VectorSearchEmbeddings(),
2628
}),
2729
config,
2830
deviceId

tests/integration/tools/mongodb/mongodbTool.test.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import { ErrorCodes } from "../../../../src/common/errors.js";
2020
import { Keychain } from "../../../../src/common/keychain.js";
2121
import { Elicitation } from "../../../../src/elicitation.js";
2222
import { MongoDbTools } from "../../../../src/tools/mongodb/tools.js";
23+
import { VectorSearchEmbeddings } from "../../../../src/common/search/vectorSearchEmbeddings.js";
2324

2425
const injectedErrorHandler: ConnectionErrorHandler = (error) => {
2526
switch (error.code) {
@@ -108,6 +109,7 @@ describe("MongoDBTool implementations", () => {
108109
exportsManager,
109110
connectionManager,
110111
keychain: new Keychain(),
112+
vectorSearchEmbeddings: new VectorSearchEmbeddings(),
111113
});
112114
const telemetry = Telemetry.create(session, userConfig, deviceId);
113115

tests/unit/common/session.test.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import { MCPConnectionManager } from "../../../src/common/connectionManager.js";
99
import { ExportsManager } from "../../../src/common/exportsManager.js";
1010
import { DeviceId } from "../../../src/helpers/deviceId.js";
1111
import { Keychain } from "../../../src/common/keychain.js";
12+
import { VectorSearchEmbeddings } from "../../../src/common/search/vectorSearchEmbeddings.js";
1213

1314
vi.mock("@mongosh/service-provider-node-driver");
1415

@@ -31,6 +32,7 @@ describe("Session", () => {
3132
exportsManager: ExportsManager.init(config, logger),
3233
connectionManager: new MCPConnectionManager(config, driverOptions, logger, mockDeviceId),
3334
keychain: new Keychain(),
35+
vectorSearchEmbeddings: new VectorSearchEmbeddings(),
3436
});
3537

3638
MockNodeDriverServiceProvider.connect = vi.fn().mockResolvedValue({} as unknown as NodeDriverServiceProvider);

tests/unit/resources/common/debug.test.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import { MCPConnectionManager } from "../../../../src/common/connectionManager.j
99
import { ExportsManager } from "../../../../src/common/exportsManager.js";
1010
import { DeviceId } from "../../../../src/helpers/deviceId.js";
1111
import { Keychain } from "../../../../src/common/keychain.js";
12+
import { VectorSearchEmbeddings } from "../../../../src/common/search/vectorSearchEmbeddings.js";
1213

1314
describe("debug resource", () => {
1415
const logger = new CompositeLogger();
@@ -19,6 +20,7 @@ describe("debug resource", () => {
1920
exportsManager: ExportsManager.init(config, logger),
2021
connectionManager: new MCPConnectionManager(config, driverOptions, logger, deviceId),
2122
keychain: new Keychain(),
23+
vectorSearchEmbeddings: new VectorSearchEmbeddings(),
2224
});
2325
const telemetry = Telemetry.create(session, { ...config, telemetry: "disabled" }, deviceId);
2426

0 commit comments

Comments
 (0)