Skip to content

Commit 885cdba

Browse files
gribnoysupaddaleax
andauthored
chore(atlas-service): add validation for backend response (#4745)
* chore(atlas-service): add validation for backend response * chore(atlas-service): use some for boolean array check; fix test description Co-authored-by: Anna Henningsen <[email protected]> * chore(atlas-service): use inspect for better error messages Co-authored-by: Anna Henningsen <[email protected]> --------- Co-authored-by: Anna Henningsen <[email protected]>
1 parent 5f0e979 commit 885cdba

File tree

9 files changed

+229
-96
lines changed

9 files changed

+229
-96
lines changed

packages/atlas-service/src/main.spec.ts

Lines changed: 61 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -14,20 +14,6 @@ function getListenerCount(emitter: EventEmitter) {
1414
}, 0);
1515
}
1616

17-
const atlasAIServiceTests: {
18-
functionName: 'getQueryFromUserInput' | 'getAggregationFromUserInput';
19-
aiEndpoint: string;
20-
}[] = [
21-
{
22-
functionName: 'getQueryFromUserInput',
23-
aiEndpoint: 'mql-query',
24-
},
25-
{
26-
functionName: 'getAggregationFromUserInput',
27-
aiEndpoint: 'mql-aggregation',
28-
},
29-
];
30-
3117
describe('AtlasServiceMain', function () {
3218
const sandbox = Sinon.createSandbox();
3319

@@ -216,15 +202,46 @@ describe('AtlasServiceMain', function () {
216202
});
217203
});
218204

219-
for (const { functionName, aiEndpoint } of atlasAIServiceTests) {
205+
const atlasAIServiceTests = [
206+
{
207+
functionName: 'getQueryFromUserInput',
208+
aiEndpoint: 'mql-query',
209+
responses: {
210+
success: {
211+
content: { query: { filter: "{ test: 'pineapple' }" } },
212+
},
213+
invalid: [
214+
{},
215+
{ countent: {} },
216+
{ content: { qooery: {} } },
217+
{ content: { query: { filter: { foo: 1 } } } },
218+
],
219+
},
220+
},
221+
{
222+
functionName: 'getAggregationFromUserInput',
223+
aiEndpoint: 'mql-aggregation',
224+
responses: {
225+
success: {
226+
content: { aggregation: { pipeline: "[{ test: 'pineapple' }]" } },
227+
},
228+
invalid: [
229+
{},
230+
{ content: { aggregation: {} } },
231+
{ content: { aggrogation: {} } },
232+
{ content: { aggregation: { pipeline: true } } },
233+
],
234+
},
235+
},
236+
] as const;
237+
238+
for (const { functionName, aiEndpoint, responses } of atlasAIServiceTests) {
220239
describe(functionName, function () {
221240
it('makes a post request with the user input to the endpoint in the environment', async function () {
222241
AtlasService['fetch'] = sandbox.stub().resolves({
223242
ok: true,
224243
json() {
225-
return Promise.resolve({
226-
content: { query: { filter: "{ test: 'pineapple' }" } },
227-
});
244+
return Promise.resolve(responses.success);
228245
},
229246
}) as any;
230247

@@ -246,10 +263,28 @@ describe('AtlasServiceMain', function () {
246263
expect(args[1].body).to.eq(
247264
'{"userInput":"test","collectionName":"jam","databaseName":"peanut","schema":{"_id":{"types":[{"bsonType":"ObjectId"}]}},"sampleDocuments":[{"_id":1234}]}'
248265
);
249-
expect(res).to.have.nested.property(
250-
'content.query.filter',
251-
"{ test: 'pineapple' }"
252-
);
266+
expect(res).to.deep.eq(responses.success);
267+
});
268+
269+
it('should fail when response is not matching expected schema', async function () {
270+
for (const res of responses.invalid) {
271+
AtlasService['fetch'] = sandbox.stub().resolves({
272+
ok: true,
273+
json() {
274+
return Promise.resolve(res);
275+
},
276+
}) as any;
277+
try {
278+
await AtlasService[functionName]({
279+
userInput: 'test',
280+
collectionName: 'test',
281+
databaseName: 'peanut',
282+
});
283+
expect.fail(`Expected ${functionName} to throw`);
284+
} catch (err) {
285+
expect((err as Error).message).to.match(/Unexpected.+?response/);
286+
}
287+
}
253288
});
254289

255290
it('uses the abort signal in the fetch request', async function () {
@@ -289,7 +324,7 @@ describe('AtlasServiceMain', function () {
289324
AtlasService['fetch'] = sandbox.stub().resolves({
290325
ok: true,
291326
json() {
292-
return Promise.resolve({});
327+
return Promise.resolve(responses.success);
293328
},
294329
}) as any;
295330

@@ -356,13 +391,13 @@ describe('AtlasServiceMain', function () {
356391
AtlasService['fetch'] = sandbox.stub().resolves({
357392
ok: true,
358393
json() {
359-
return Promise.resolve({ test: 1 });
394+
return Promise.resolve(responses.success);
360395
},
361396
}) as any;
362397
AtlasService['oidcPluginLogger'].emit(
363398
'mongodb-oidc-plugin:refresh-started'
364399
);
365-
const [query] = await Promise.all([
400+
const [res] = await Promise.all([
366401
AtlasService[functionName]({
367402
userInput: 'test',
368403
collectionName: 'test',
@@ -378,7 +413,7 @@ describe('AtlasServiceMain', function () {
378413
);
379414
})(),
380415
]);
381-
expect(query).to.deep.eq({ test: 1 });
416+
expect(res).to.deep.eq(responses.success);
382417
});
383418
});
384419
}

packages/atlas-service/src/main.ts

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,14 @@ import type { Response } from 'node-fetch';
1919
import fetch from 'node-fetch';
2020
import type { SimplifiedSchema } from 'mongodb-schema';
2121
import type { Document } from 'mongodb';
22-
import type {
23-
AIAggregation,
24-
AIQuery,
25-
IntrospectInfo,
26-
Token,
27-
UserInfo,
22+
import {
23+
validateAIQueryResponse,
24+
type AIAggregation,
25+
type AIQuery,
26+
type IntrospectInfo,
27+
type Token,
28+
type UserInfo,
29+
validateAIAggregationResponse,
2830
} from './util';
2931
import {
3032
broadcast,
@@ -733,7 +735,7 @@ export class AtlasService {
733735
schema?: SimplifiedSchema;
734736
sampleDocuments?: Document[];
735737
signal?: AbortSignal;
736-
}) {
738+
}): Promise<AIAggregation> {
737739
throwIfAborted(signal);
738740
throwIfNetworkTrafficDisabled();
739741

@@ -780,7 +782,11 @@ export class AtlasService {
780782

781783
await throwIfNotOk(res);
782784

783-
return res.json() as Promise<AIAggregation>;
785+
const body = await res.json();
786+
787+
validateAIAggregationResponse(body);
788+
789+
return body;
784790
}
785791

786792
static async getQueryFromUserInput({
@@ -797,7 +803,7 @@ export class AtlasService {
797803
schema?: SimplifiedSchema;
798804
sampleDocuments?: Document[];
799805
signal?: AbortSignal;
800-
}) {
806+
}): Promise<AIQuery> {
801807
throwIfAborted(signal);
802808
throwIfNetworkTrafficDisabled();
803809

@@ -841,7 +847,11 @@ export class AtlasService {
841847

842848
await throwIfNotOk(res);
843849

844-
return res.json() as Promise<AIQuery>;
850+
const body = await res.json();
851+
852+
validateAIQueryResponse(body);
853+
854+
return body;
845855
}
846856

847857
static async onExit() {

packages/atlas-service/src/renderer.ts

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,4 +110,11 @@ export class AtlasService implements AtlasServiceEmitter {
110110

111111
export { AtlasSignIn } from './components/atlas-signin';
112112

113-
export type { UserInfo, IntrospectInfo, Token } from './util';
113+
export type {
114+
UserInfo,
115+
IntrospectInfo,
116+
Token,
117+
AtlasServiceNetworkError,
118+
AIQuery,
119+
AIAggregation,
120+
} from './util';

packages/atlas-service/src/util.ts

Lines changed: 100 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import type * as plugin from '@mongodb-js/oidc-plugin';
2+
import util from 'util';
23

34
export type UserInfo = {
45
firstName: string;
@@ -11,16 +12,111 @@ export type IntrospectInfo = { active: boolean };
1112

1213
export type Token = plugin.IdPServerResponse;
1314

15+
function hasExtraneousKeys(obj: any, expectedKeys: string[]) {
16+
return Object.keys(obj).some((key) => !expectedKeys.includes(key));
17+
}
18+
1419
export type AIAggregation = {
15-
content?: {
20+
content: {
1621
aggregation?: {
17-
pipeline?: unknown;
22+
pipeline?: string;
1823
};
1924
};
2025
};
2126

27+
export function validateAIAggregationResponse(
28+
response: any
29+
): asserts response is AIAggregation {
30+
const { content } = response;
31+
32+
if (typeof content !== 'object' || content === null) {
33+
throw new Error('Unexpected response: expected content to be an object');
34+
}
35+
36+
if (hasExtraneousKeys(content, ['aggregation'])) {
37+
throw new Error('Unexpected keys in response: expected aggregation');
38+
}
39+
40+
if (content.aggregation && typeof content.aggregation.pipeline !== 'string') {
41+
// Compared to queries where we will always get the `query` field, for
42+
// aggregations backend deletes the whole `aggregation` key if pipeline is
43+
// empty, so we only validate `pipeline` key if `aggregation` key is present
44+
throw new Error(
45+
`Unexpected response: expected aggregation to be a string, got ${String(
46+
content.aggregation.pipeline
47+
)}`
48+
);
49+
}
50+
}
51+
2252
export type AIQuery = {
23-
content?: {
24-
query?: unknown;
53+
content: {
54+
query: Record<
55+
'filter' | 'project' | 'collation' | 'sort' | 'skip' | 'limit',
56+
string
57+
> & { aggregation?: { pipeline: string } };
2558
};
2659
};
60+
61+
export function validateAIQueryResponse(
62+
response: any
63+
): asserts response is AIQuery {
64+
const { content } = response ?? {};
65+
66+
if (typeof content !== 'object' || content === null) {
67+
throw new Error('Unexpected response: expected content to be an object');
68+
}
69+
70+
if (hasExtraneousKeys(content, ['query'])) {
71+
throw new Error('Unexpected keys in response: expected query');
72+
}
73+
74+
const { query } = content;
75+
76+
if (typeof query !== 'object' || query === null) {
77+
throw new Error('Unexpected response: expected query to be an object');
78+
}
79+
80+
if (
81+
hasExtraneousKeys(query, [
82+
'filter',
83+
'project',
84+
'collation',
85+
'sort',
86+
'skip',
87+
'limit',
88+
'aggregation',
89+
])
90+
) {
91+
throw new Error(
92+
'Unexpected keys in response: expected filter, project, collation, sort, skip, limit, aggregation'
93+
);
94+
}
95+
96+
for (const field of [
97+
'filter',
98+
'project',
99+
'collation',
100+
'sort',
101+
'skip',
102+
'limit',
103+
]) {
104+
if (query[field] && typeof query[field] !== 'string') {
105+
throw new Error(
106+
`Unexpected response: expected field ${field} to be a string, got ${util.inspect(
107+
query[field]
108+
)}`
109+
);
110+
}
111+
}
112+
113+
if (query.aggregation && typeof query.aggregation.pipeline !== 'string') {
114+
throw new Error(
115+
`Unexpected response: expected aggregation pipeline to be a string, got ${util.inspect(
116+
query.aggregation
117+
)}`
118+
);
119+
}
120+
}
121+
122+
export type AtlasServiceNetworkError = Error & { statusCode: number };

packages/compass-aggregations/src/modules/pipeline-builder/pipeline-ai.spec.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ describe('AIPipelineReducer', function () {
2828
it('should succeed', async function () {
2929
const mockAtlasService = {
3030
getAggregationFromUserInput: sandbox.stub().resolves({
31-
content: { aggregation: { pipeline: [{ $match: { _id: 1 } }] } },
31+
content: { aggregation: { pipeline: '[{ $match: { _id: 1 } }]' } },
3232
}),
3333
};
3434

0 commit comments

Comments
 (0)