1
1
import Sinon from 'sinon' ;
2
2
import { expect } from 'chai' ;
3
- import { AtlasService } from './main' ;
3
+ import { AtlasService , throwIfNotOk } from './main' ;
4
4
5
5
describe ( 'AtlasServiceMain' , function ( ) {
6
6
const sandbox = Sinon . createSandbox ( ) ;
@@ -23,17 +23,22 @@ describe('AtlasServiceMain', function () {
23
23
24
24
AtlasService [ 'plugin' ] = mockOidcPlugin ;
25
25
26
+ const fetch = AtlasService [ 'fetch' ] ;
27
+ const apiBaseUrl = process . env . DEV_AI_QUERY_ENDPOINT ;
26
28
const issuer = process . env . COMPASS_OIDC_ISSUER ;
27
29
const clientId = process . env . COMPASS_CLIENT_ID ;
28
30
29
31
before ( function ( ) {
32
+ process . env . DEV_AI_QUERY_ENDPOINT = 'http://example.com' ;
30
33
process . env . COMPASS_OIDC_ISSUER = 'http://example.com' ;
31
34
process . env . COMPASS_CLIENT_ID = '1234abcd' ;
32
35
} ) ;
33
36
34
37
after ( function ( ) {
38
+ process . env . DEV_AI_QUERY_ENDPOINT = apiBaseUrl ;
35
39
process . env . COMPASS_OIDC_ISSUER = issuer ;
36
40
process . env . COMPASS_CLIENT_ID = clientId ;
41
+ AtlasService [ 'fetch' ] = fetch ;
37
42
} ) ;
38
43
39
44
afterEach ( function ( ) {
@@ -88,4 +93,185 @@ describe('AtlasServiceMain', function () {
88
93
expect ( err ) . to . have . property ( 'message' , 'COMPASS_CLIENT_ID is required' ) ;
89
94
}
90
95
} ) ;
96
+
97
+ describe ( 'getQueryFromUserPrompt' , function ( ) {
98
+ it ( 'makes a post request with the user prompt to the endpoint in the environment' , async function ( ) {
99
+ AtlasService [ 'fetch' ] = sandbox . stub ( ) . resolves ( {
100
+ ok : true ,
101
+ json ( ) {
102
+ return Promise . resolve ( {
103
+ content : { query : { find : { test : 'pineapple' } } } ,
104
+ } ) ;
105
+ } ,
106
+ } ) as any ;
107
+
108
+ const res = await AtlasService . getQueryFromUserPrompt ( {
109
+ userPrompt : 'test' ,
110
+ signal : new AbortController ( ) . signal ,
111
+ collectionName : 'jam' ,
112
+ schema : { _id : { types : [ { bsonType : 'ObjectId' } ] } } ,
113
+ sampleDocuments : [ { _id : 1234 } ] ,
114
+ } ) ;
115
+
116
+ const { args } = (
117
+ AtlasService [ 'fetch' ] as unknown as Sinon . SinonStub
118
+ ) . getCall ( 0 ) ;
119
+
120
+ expect ( AtlasService [ 'fetch' ] ) . to . have . been . calledOnce ;
121
+ expect ( args [ 0 ] ) . to . eq ( 'http://example.com/ai/api/v1/mql-query' ) ;
122
+ expect ( args [ 1 ] . body ) . to . eq (
123
+ '{"userPrompt":"test","collectionName":"jam","schema":{"_id":{"types":[{"bsonType":"ObjectId"}]}},"sampleDocuments":[{"_id":1234}]}'
124
+ ) ;
125
+ expect ( res ) . to . have . nested . property (
126
+ 'content.query.find.test' ,
127
+ 'pineapple'
128
+ ) ;
129
+ } ) ;
130
+
131
+ it ( 'uses the abort signal in the fetch request' , async function ( ) {
132
+ const c = new AbortController ( ) ;
133
+ c . abort ( ) ;
134
+ try {
135
+ await AtlasService . getQueryFromUserPrompt ( {
136
+ signal : c . signal ,
137
+ userPrompt : 'test' ,
138
+ collectionName : 'test.test' ,
139
+ } ) ;
140
+ expect . fail ( 'Expected getQueryFromUserPrompt to throw' ) ;
141
+ } catch ( err ) {
142
+ expect ( err ) . to . have . property ( 'message' , 'This operation was aborted' ) ;
143
+ }
144
+ } ) ;
145
+
146
+ it ( 'throws if the request would be too much for the ai' , async function ( ) {
147
+ try {
148
+ await AtlasService . getQueryFromUserPrompt ( {
149
+ userPrompt : 'test' ,
150
+ collectionName : 'test.test' ,
151
+ sampleDocuments : [ { test : '4' . repeat ( 60000 ) } ] ,
152
+ } ) ;
153
+ expect . fail ( 'Expected getQueryFromUserPrompt to throw' ) ;
154
+ } catch ( err ) {
155
+ expect ( err ) . to . have . property (
156
+ 'message' ,
157
+ 'Error: too large of a request to send to the ai. Please use a smaller prompt or collection with smaller documents.'
158
+ ) ;
159
+ }
160
+ } ) ;
161
+
162
+ it ( 'passes fewer documents if the request would be too much for the ai with all of the documents' , async function ( ) {
163
+ AtlasService [ 'fetch' ] = sandbox . stub ( ) . resolves ( {
164
+ ok : true ,
165
+ json ( ) {
166
+ return Promise . resolve ( { } ) ;
167
+ } ,
168
+ } ) as any ;
169
+
170
+ await AtlasService . getQueryFromUserPrompt ( {
171
+ userPrompt : 'test' ,
172
+ collectionName : 'test.test' ,
173
+ sampleDocuments : [
174
+ { a : '1' } ,
175
+ { a : '2' } ,
176
+ { a : '3' } ,
177
+ { a : '4' . repeat ( 50000 ) } ,
178
+ ] ,
179
+ } ) ;
180
+
181
+ const { args } = (
182
+ AtlasService [ 'fetch' ] as unknown as Sinon . SinonStub
183
+ ) . getCall ( 0 ) ;
184
+
185
+ expect ( AtlasService [ 'fetch' ] ) . to . have . been . calledOnce ;
186
+ expect ( args [ 1 ] . body ) . to . eq (
187
+ '{"userPrompt":"test","collectionName":"test.test","sampleDocuments":[{"a":"1"}]}'
188
+ ) ;
189
+ } ) ;
190
+
191
+ it ( 'throws the error' , async function ( ) {
192
+ AtlasService [ 'fetch' ] = sandbox . stub ( ) . resolves ( {
193
+ ok : false ,
194
+ status : 500 ,
195
+ statusText : 'Internal Server Error' ,
196
+ } ) as any ;
197
+
198
+ try {
199
+ await AtlasService . getQueryFromUserPrompt ( {
200
+ userPrompt : 'test' ,
201
+ collectionName : 'test.test' ,
202
+ } ) ;
203
+ expect . fail ( 'Expected getQueryFromUserPrompt to throw' ) ;
204
+ } catch ( err ) {
205
+ expect ( err ) . to . have . property ( 'message' , '500 Internal Server Error' ) ;
206
+ }
207
+ } ) ;
208
+
209
+ it ( 'should throw if DEV_AI_QUERY_ENDPOINT is not set' , async function ( ) {
210
+ delete process . env . DEV_AI_QUERY_ENDPOINT ;
211
+
212
+ try {
213
+ await AtlasService . getQueryFromUserPrompt ( {
214
+ userPrompt : 'test' ,
215
+ collectionName : 'test.test' ,
216
+ } ) ;
217
+ expect . fail ( 'Expected AtlasService.signIn() to throw' ) ;
218
+ } catch ( err ) {
219
+ expect ( err ) . to . have . property (
220
+ 'message' ,
221
+ 'No AI Query endpoint to fetch. Please set the environment variable `DEV_AI_QUERY_ENDPOINT`'
222
+ ) ;
223
+ }
224
+ } ) ;
225
+ } ) ;
226
+
227
+ describe ( 'throwIfNotOk' , function ( ) {
228
+ it ( 'should not throw if res is ok' , async function ( ) {
229
+ await throwIfNotOk ( {
230
+ ok : true ,
231
+ status : 200 ,
232
+ statusText : 'OK' ,
233
+ json ( ) {
234
+ return Promise . resolve ( { } ) ;
235
+ } ,
236
+ } ) ;
237
+ } ) ;
238
+
239
+ it ( 'should throw network error if res is not ok' , async function ( ) {
240
+ try {
241
+ await throwIfNotOk ( {
242
+ ok : false ,
243
+ status : 500 ,
244
+ statusText : 'Whoops' ,
245
+ json ( ) {
246
+ return Promise . resolve ( { } ) ;
247
+ } ,
248
+ } ) ;
249
+ expect . fail ( 'Expected throwIfNotOk to throw' ) ;
250
+ } catch ( err ) {
251
+ expect ( err ) . to . have . property ( 'name' , 'NetworkError' ) ;
252
+ expect ( err ) . to . have . property ( 'message' , '500 Whoops' ) ;
253
+ }
254
+ } ) ;
255
+
256
+ it ( 'should try to parse AIError from body and throw it' , async function ( ) {
257
+ try {
258
+ await throwIfNotOk ( {
259
+ ok : false ,
260
+ status : 500 ,
261
+ statusText : 'Whoops' ,
262
+ json ( ) {
263
+ return Promise . resolve ( {
264
+ name : 'AIError' ,
265
+ errorMessage : 'tortillas' ,
266
+ codeName : 'ExampleCode' ,
267
+ } ) ;
268
+ } ,
269
+ } ) ;
270
+ expect . fail ( 'Expected throwIfNotOk to throw' ) ;
271
+ } catch ( err ) {
272
+ expect ( err ) . to . have . property ( 'name' , 'Error' ) ;
273
+ expect ( err ) . to . have . property ( 'message' , 'ExampleCode: tortillas' ) ;
274
+ }
275
+ } ) ;
276
+ } ) ;
91
277
} ) ;
0 commit comments