diff --git a/package-lock.json b/package-lock.json index 80ae040abf8..f54bb7507a4 100644 --- a/package-lock.json +++ b/package-lock.json @@ -49297,7 +49297,8 @@ "react": "^17.0.2", "react-redux": "^8.1.3", "redux": "^4.2.1", - "redux-thunk": "^2.4.2" + "redux-thunk": "^2.4.2", + "zod": "^3.25.76" }, "devDependencies": { "@mongodb-js/connection-info": "^0.17.2", @@ -49383,6 +49384,15 @@ "url": "https://opencollective.com/sinon" } }, + "packages/compass-generative-ai/node_modules/zod": { + "version": "3.25.76", + "resolved": "https://registry.npmjs.org/zod/-/zod-3.25.76.tgz", + "integrity": "sha512-gzUt/qt81nXsFGKIFcC3YnfEAx5NkunCfnDlvuBSSFS02bcXu4Lmea0AFIUwbLWxWPx3d9p8S5QoaujKcNQxcQ==", + "license": "MIT", + "funding": { + "url": "https://github.com/sponsors/colinhacks" + } + }, "packages/compass-global-writes": { "name": "@mongodb-js/compass-global-writes", "version": "1.30.0", @@ -62101,7 +62111,8 @@ "redux-thunk": "^2.4.2", "sinon": "^9.2.3", "typescript": "^5.9.2", - "xvfb-maybe": "^0.2.1" + "xvfb-maybe": "^0.2.1", + "zod": "^3.25.76" }, "dependencies": { "diff": { @@ -62145,6 +62156,11 @@ "nise": "^4.0.4", "supports-color": "^7.1.0" } + }, + "zod": { + "version": "3.25.76", + "resolved": "https://registry.npmjs.org/zod/-/zod-3.25.76.tgz", + "integrity": "sha512-gzUt/qt81nXsFGKIFcC3YnfEAx5NkunCfnDlvuBSSFS02bcXu4Lmea0AFIUwbLWxWPx3d9p8S5QoaujKcNQxcQ==" } } }, diff --git a/packages/compass-generative-ai/package.json b/packages/compass-generative-ai/package.json index 2c98a4cd387..bac59b8e581 100644 --- a/packages/compass-generative-ai/package.json +++ b/packages/compass-generative-ai/package.json @@ -67,7 +67,8 @@ "react": "^17.0.2", "react-redux": "^8.1.3", "redux": "^4.2.1", - "redux-thunk": "^2.4.2" + "redux-thunk": "^2.4.2", + "zod": "^3.25.76" }, "devDependencies": { "@mongodb-js/connection-info": "^0.17.2", diff --git a/packages/compass-generative-ai/src/atlas-ai-errors.ts b/packages/compass-generative-ai/src/atlas-ai-errors.ts new file mode 100644 index 00000000000..992377cf18b --- /dev/null +++ b/packages/compass-generative-ai/src/atlas-ai-errors.ts @@ -0,0 +1,21 @@ +/** + * Occurs when the input to the AtlasAiService is understood but invalid. + */ +class AtlasAiServiceInvalidInputError extends Error { + constructor(message: string) { + super(message); + this.name = 'AtlasAiServiceInvalidInputError'; + } +} + +/** + * Thrown when the API response cannot be parsed into the expected shape.. + */ +class AtlasAiServiceApiResponseParseError extends Error { + constructor(message: string) { + super(message); + this.name = 'AtlasAiServiceApiResponseParseError'; + } +} + +export { AtlasAiServiceInvalidInputError, AtlasAiServiceApiResponseParseError }; diff --git a/packages/compass-generative-ai/src/atlas-ai-service.spec.ts b/packages/compass-generative-ai/src/atlas-ai-service.spec.ts index 193a437acff..5d86731562f 100644 --- a/packages/compass-generative-ai/src/atlas-ai-service.spec.ts +++ b/packages/compass-generative-ai/src/atlas-ai-service.spec.ts @@ -1,6 +1,10 @@ import Sinon from 'sinon'; import { expect } from 'chai'; import { AtlasAiService } from './atlas-ai-service'; +import { + AtlasAiServiceInvalidInputError, + AtlasAiServiceApiResponseParseError, +} from './atlas-ai-errors'; import type { PreferencesAccess } from 'compass-preferences-model'; import { createSandboxFromDefaultPreferences } from 'compass-preferences-model'; import { createNoopLogger } from '@mongodb-js/compass-logging/provider'; @@ -92,6 +96,8 @@ describe('AtlasAiService', function () { 'mql-aggregation': '/cloud/ai/v1/groups/testProject/mql-aggregation?request_id=abc', 'mql-query': '/cloud/ai/v1/groups/testProject/mql-query?request_id=abc', + 'mock-data-schema': + '/cloud/ai/v1/groups/testProject/mock-data-schema?request_id=abc', }, }, ] as const; @@ -110,7 +116,7 @@ describe('AtlasAiService', function () { }); }); - describe('ai api calls', function () { + describe('getQueryFromUserInput and getAggregationFromUserInput', function () { beforeEach(async function () { // Enable the AI feature const fetchStub = sandbox.stub().resolves( @@ -369,6 +375,292 @@ describe('AtlasAiService', function () { expect(currentPreferences.optInGenAIFeatures).to.equal(true); }); }); + + describe('getMockDataSchema', function () { + beforeEach(async function () { + // Enable the AI feature + const fetchStub = sandbox.stub().resolves( + makeResponse({ + features: { + GEN_AI_COMPASS: { + enabled: true, + }, + }, + }) + ); + global.fetch = fetchStub; + await atlasAiService['setupAIAccess'](); + global.fetch = initialFetch; + }); + + const mockSchemaInput = { + collectionName: 'test-collection', + databaseName: 'test-db', + schema: { + name: { + type: 'string', + sampleValues: ['John', 'Jane', 'Bob'], + probability: 0.9, + }, + age: { + type: 'number', + sampleValues: [25, 30, 35], + probability: 0.8, + }, + }, + includeSampleValues: false, + }; + + if (apiURLPreset === 'admin-api') { + it('throws AtlasAiServiceInvalidInputError for admin-api preset', async function () { + try { + await atlasAiService.getMockDataSchema( + mockSchemaInput, + mockConnectionInfo + ); + expect.fail( + 'Expected getMockDataSchema to throw for admin-api preset' + ); + } catch (err) { + expect(err).to.be.instanceOf(AtlasAiServiceInvalidInputError); + expect((err as Error).message).to.match( + /mock-data-schema is not available for admin-api/i + ); + } + }); + } + + if (apiURLPreset === 'cloud') { + it('makes a post request to the correct endpoint', async function () { + const mockResponse = { + content: { + fields: [ + { + fieldPath: 'name', + mongoType: 'string', + fakerMethod: 'person.fullName', + fakerArgs: [], + isArray: false, + probability: 1.0, + }, + { + fieldPath: 'age', + mongoType: 'int', + fakerMethod: 'number.int', + fakerArgs: [{ json: '{"min": 18, "max": 65}' }], + isArray: false, + probability: 0.8, + }, + ], + }, + }; + const fetchStub = sandbox + .stub() + .resolves(makeResponse(mockResponse)); + global.fetch = fetchStub; + + const result = await atlasAiService.getMockDataSchema( + mockSchemaInput, + mockConnectionInfo + ); + + expect(fetchStub).to.have.been.calledOnce; + const { args } = fetchStub.firstCall; + expect(args[0]).to.eq( + '/cloud/ai/v1/groups/testProject/mock-data-schema' + ); + expect(result).to.deep.equal(mockResponse); + }); + + it('includes sample values by default (includeSampleValues=true)', async function () { + const mockResponse = { + content: { + fields: [ + { + fieldPath: 'name', + mongoType: 'string', + fakerMethod: 'person.fullName', + fakerArgs: [], + isArray: false, + probability: 1.0, + }, + { + fieldPath: 'age', + mongoType: 'int', + fakerMethod: 'number.int', + fakerArgs: [{ json: '{"min": 18, "max": 122}' }], + isArray: false, + probability: 0.8, + }, + ], + }, + }; + const fetchStub = sandbox + .stub() + .resolves(makeResponse(mockResponse)); + global.fetch = fetchStub; + + await atlasAiService.getMockDataSchema( + { ...mockSchemaInput, includeSampleValues: true }, + mockConnectionInfo + ); + + const { args } = fetchStub.firstCall; + const requestBody = JSON.parse(args[1].body); + + expect(requestBody.schema.name.sampleValues).to.deep.equal([ + 'John', + 'Jane', + 'Bob', + ]); + expect(requestBody.schema.age.sampleValues).to.deep.equal([ + 25, 30, 35, + ]); + }); + + it('excludes sample values when includeSampleValues=false', async function () { + const mockResponse = { + content: { + fields: [ + { + fieldPath: 'name', + mongoType: 'string', + fakerMethod: 'person.fullName', + fakerArgs: [], + isArray: false, + probability: 1.0, + }, + { + fieldPath: 'age', + mongoType: 'int', + fakerMethod: 'number.int', + fakerArgs: [{ json: '{"min": 18, "max": 65}' }], + isArray: false, + probability: 0.8, + }, + ], + }, + }; + const fetchStub = sandbox + .stub() + .resolves(makeResponse(mockResponse)); + global.fetch = fetchStub; + + await atlasAiService.getMockDataSchema( + mockSchemaInput, + mockConnectionInfo + ); + + const { args } = fetchStub.firstCall; + const requestBody = JSON.parse(args[1].body); + + expect(requestBody.schema.name).to.not.have.property( + 'sampleValues' + ); + expect(requestBody.schema.age).to.not.have.property('sampleValues'); + expect(requestBody.schema.name.type).to.equal('string'); + expect(requestBody.schema.age.probability).to.equal(0.8); + }); + + it('makes POST request with correct headers and body structure', async function () { + const mockResponse = { + content: { + fields: [ + { + fieldPath: 'name', + mongoType: 'string', + fakerMethod: 'person.fullName', + fakerArgs: [], + isArray: false, + probability: 1.0, + }, + { + fieldPath: 'age', + mongoType: 'int', + fakerMethod: 'number.int', + fakerArgs: [{ json: '{"min": 18, "max": 65}' }], + isArray: false, + probability: 0.8, + }, + ], + }, + }; + const fetchStub = sandbox + .stub() + .resolves(makeResponse(mockResponse)); + global.fetch = fetchStub; + + await atlasAiService.getMockDataSchema( + mockSchemaInput, + mockConnectionInfo + ); + + const { args } = fetchStub.firstCall; + + expect(args[1].method).to.equal('POST'); + expect(args[1].headers['Content-Type']).to.equal( + 'application/json' + ); + expect(args[1].headers['Accept']).to.equal('application/json'); + + const requestBody = JSON.parse(args[1].body); + expect(requestBody).to.have.property( + 'collectionName', + 'test-collection' + ); + expect(requestBody).to.have.property('databaseName', 'test-db'); + expect(requestBody).to.have.property('schema'); + }); + + it('throws AtlasAiServiceInvalidInputError when connection info lacks atlas metadata', async function () { + const connectionInfoWithoutAtlas = { + ...mockConnectionInfo, + atlasMetadata: undefined, + }; + + try { + await atlasAiService.getMockDataSchema( + mockSchemaInput, + connectionInfoWithoutAtlas as any + ); + expect.fail('Expected getMockDataSchema to throw'); + } catch (err) { + expect(err).to.be.instanceOf(AtlasAiServiceInvalidInputError); + expect((err as Error).message).to.match( + /atlasMetadata is not available/i + ); + } + }); + + it('throws AtlasAiServiceApiResponseParseError when API response has invalid format', async function () { + const invalidMockResponse = { + invalidField: 'invalid data', + content: { + wrongFieldName: [], + }, + }; + const fetchStub = sandbox + .stub() + .resolves(makeResponse(invalidMockResponse)); + global.fetch = fetchStub; + + try { + await atlasAiService.getMockDataSchema( + mockSchemaInput, + mockConnectionInfo + ); + expect.fail( + 'Expected getMockDataSchema to throw AtlasAiServiceApiResponseParseError' + ); + } catch (err) { + expect(err).to.be.instanceOf(AtlasAiServiceApiResponseParseError); + expect((err as Error).message).to.equal( + 'Response does not match expected schema' + ); + } + }); + } + }); }); } }); diff --git a/packages/compass-generative-ai/src/atlas-ai-service.ts b/packages/compass-generative-ai/src/atlas-ai-service.ts index 8c5a39b1aa1..9cb4c23b803 100644 --- a/packages/compass-generative-ai/src/atlas-ai-service.ts +++ b/packages/compass-generative-ai/src/atlas-ai-service.ts @@ -9,9 +9,15 @@ import type { ConnectionInfo } from '@mongodb-js/compass-connections/provider'; import type { Document } from 'mongodb'; import type { Logger } from '@mongodb-js/compass-logging'; import { EJSON } from 'bson'; +import { z } from 'zod'; import { getStore } from './store/atlas-ai-store'; import { optIntoGenAIWithModalPrompt } from './store/atlas-optin-reducer'; import { signIntoAtlasWithModalPrompt } from './store/atlas-signin-reducer'; +import { + AtlasAiServiceInvalidInputError, + AtlasAiServiceApiResponseParseError, +} from './atlas-ai-errors'; +import { mongoLogId } from '@mongodb-js/compass-logging/provider'; type GenerativeAiInput = { userInput: string; @@ -204,9 +210,57 @@ const aiURLConfig = { cloud: { aggregation: (groupId: string) => `ai/v1/groups/${groupId}/mql-aggregation`, query: (groupId: string) => `ai/v1/groups/${groupId}/mql-query`, + 'mock-data-schema': (groupId: string) => + `ai/v1/groups/${groupId}/mock-data-schema`, }, } as const; -type AIEndpoint = 'query' | 'aggregation'; + +export interface MockDataSchemaRawField { + type: string; + sampleValues?: unknown[]; + probability?: number; +} + +export interface MockDataSchemaRequest { + collectionName: string; + databaseName: string; + schema: Record; + validationRules?: Record | null; + includeSampleValues?: boolean; +} + +export const MockDataSchemaResponseShape = z.object({ + content: z.object({ + fields: z.array( + z.object({ + fieldPath: z.string(), + mongoType: z.string(), + fakerMethod: z.string(), + fakerArgs: z.array( + z.union([ + z.object({ + json: z.string(), + }), + z.string(), + z.number(), + z.boolean(), + ]) + ), + isArray: z.boolean(), + probability: z.number(), + }) + ), + }), +}); + +export type MockDataSchemaResponse = z.infer< + typeof MockDataSchemaResponseShape +>; + +/** + * The type of resource from the natural language query REST API + */ +type AIResourceType = 'query' | 'aggregation' | 'mock-data-schema'; export class AtlasAiService { private initPromise: Promise | null = null; @@ -235,23 +289,33 @@ export class AtlasAiService { this.initPromise = this.setupAIAccess(); } + /** + * @throws {AtlasAiServiceInvalidInputError} when given invalid arguments + */ private getUrlForEndpoint( - urlId: AIEndpoint, + resourceType: AIResourceType, connectionInfo?: ConnectionInfo ) { if (this.apiURLPreset === 'cloud') { const atlasMetadata = connectionInfo?.atlasMetadata; if (!atlasMetadata) { - throw new Error( + throw new AtlasAiServiceInvalidInputError( "Can't perform generative ai request: atlasMetadata is not available" ); } return this.atlasService.cloudEndpoint( - aiURLConfig[this.apiURLPreset][urlId](atlasMetadata.projectId) + aiURLConfig[this.apiURLPreset][resourceType](atlasMetadata.projectId) ); } - const urlPath = aiURLConfig[this.apiURLPreset][urlId]; + + if (resourceType === 'mock-data-schema') { + throw new AtlasAiServiceInvalidInputError( + "Can't perform generative ai request: mock-data-schema is not available for admin-api" + ); + } + + const urlPath = aiURLConfig[this.apiURLPreset][resourceType]; return this.atlasService.adminApiEndpoint(urlPath); } @@ -395,6 +459,59 @@ export class AtlasAiService { ); } + async getMockDataSchema( + input: MockDataSchemaRequest, + connectionInfo: ConnectionInfo + ): Promise { + const { collectionName, databaseName } = input; + let schema = input.schema; + + const url = this.getUrlForEndpoint('mock-data-schema', connectionInfo); + + if (!input.includeSampleValues) { + const newSchema: Record< + string, + Omit + > = {}; + for (const [k, v] of Object.entries(schema)) { + newSchema[k] = { type: v.type, probability: v.probability }; + } + schema = newSchema; + } + + const res = await this.atlasService.authenticatedFetch(url, { + method: 'POST', + body: JSON.stringify({ + collectionName, + databaseName, + schema, + }), + headers: { + 'Content-Type': 'application/json', + Accept: 'application/json', + }, + }); + + try { + const data = await res.json(); + return MockDataSchemaResponseShape.parse(data); + } catch (err) { + const errorMessage = err instanceof Error ? err.stack : String(err); + this.logger.log.error( + mongoLogId(1_001_000_311), + 'AtlasAiService', + 'Failed to parse mock data schema response with expected schema', + { + namespace: `${databaseName}.${collectionName}`, + message: errorMessage, + } + ); + throw new AtlasAiServiceApiResponseParseError( + 'Response does not match expected schema' + ); + } + } + async optIntoGenAIFeatures() { if (this.apiURLPreset === 'cloud') { // Performs a post request to Atlas to set the user opt in preference to true. diff --git a/packages/compass-generative-ai/src/index.ts b/packages/compass-generative-ai/src/index.ts index da945a0f8c7..525d5b12008 100644 --- a/packages/compass-generative-ai/src/index.ts +++ b/packages/compass-generative-ai/src/index.ts @@ -23,3 +23,16 @@ export { GenerativeAIInput, createAIPlaceholderHTMLPlaceholder, } from './components'; + +export { MockDataSchemaResponseShape } from './atlas-ai-service'; + +export { + AtlasAiServiceInvalidInputError, + AtlasAiServiceApiResponseParseError, +} from './atlas-ai-errors'; + +export type { + MockDataSchemaRequest, + MockDataSchemaRawField, + MockDataSchemaResponse, +} from './atlas-ai-service';