@@ -15,8 +15,15 @@ import { CursorTimeoutContext } from '../cursor/abstract_cursor';
1515import { getSocks , type SocksLib } from '../deps' ;
1616import { MongoOperationTimeoutError } from '../error' ;
1717import { type MongoClient , type MongoClientOptions } from '../mongo_client' ;
18+ import { type Abortable } from '../mongo_types' ;
1819import { Timeout , type TimeoutContext , TimeoutError } from '../timeout' ;
19- import { BufferPool , MongoDBCollectionNamespace , promiseWithResolvers } from '../utils' ;
20+ import {
21+ addAbortListener ,
22+ BufferPool ,
23+ kDispose ,
24+ MongoDBCollectionNamespace ,
25+ promiseWithResolvers
26+ } from '../utils' ;
2027import { autoSelectSocketOptions , type DataKey } from './client_encryption' ;
2128import { MongoCryptError } from './errors' ;
2229import { type MongocryptdManager } from './mongocryptd_manager' ;
@@ -189,7 +196,7 @@ export class StateMachine {
189196 async execute (
190197 executor : StateMachineExecutable ,
191198 context : MongoCryptContext ,
192- timeoutContext ?: TimeoutContext
199+ options : { timeoutContext ?: TimeoutContext } & Abortable
193200 ) : Promise < Uint8Array > {
194201 const keyVaultNamespace = executor . _keyVaultNamespace ;
195202 const keyVaultClient = executor . _keyVaultClient ;
@@ -214,7 +221,7 @@ export class StateMachine {
214221 metaDataClient ,
215222 context . ns ,
216223 filter ,
217- timeoutContext
224+ options
218225 ) ;
219226 if ( collInfo ) {
220227 context . addMongoOperationResponse ( collInfo ) ;
@@ -235,9 +242,9 @@ export class StateMachine {
235242 // When we are using the shared library, we don't have a mongocryptd manager.
236243 const markedCommand : Uint8Array = mongocryptdManager
237244 ? await mongocryptdManager . withRespawn (
238- this . markCommand . bind ( this , mongocryptdClient , context . ns , command , timeoutContext )
245+ this . markCommand . bind ( this , mongocryptdClient , context . ns , command , options )
239246 )
240- : await this . markCommand ( mongocryptdClient , context . ns , command , timeoutContext ) ;
247+ : await this . markCommand ( mongocryptdClient , context . ns , command , options ) ;
241248
242249 context . addMongoOperationResponse ( markedCommand ) ;
243250 context . finishMongoOperation ( ) ;
@@ -246,12 +253,7 @@ export class StateMachine {
246253
247254 case MONGOCRYPT_CTX_NEED_MONGO_KEYS : {
248255 const filter = context . nextMongoOperation ( ) ;
249- const keys = await this . fetchKeys (
250- keyVaultClient ,
251- keyVaultNamespace ,
252- filter ,
253- timeoutContext
254- ) ;
256+ const keys = await this . fetchKeys ( keyVaultClient , keyVaultNamespace , filter , options ) ;
255257
256258 if ( keys . length === 0 ) {
257259 // See docs on EMPTY_V
@@ -273,7 +275,7 @@ export class StateMachine {
273275 }
274276
275277 case MONGOCRYPT_CTX_NEED_KMS : {
276- await Promise . all ( this . requests ( context , timeoutContext ) ) ;
278+ await Promise . all ( this . requests ( context , options ) ) ;
277279 context . finishKMSRequests ( ) ;
278280 break ;
279281 }
@@ -315,11 +317,13 @@ export class StateMachine {
315317 * @param kmsContext - A C++ KMS context returned from the bindings
316318 * @returns A promise that resolves when the KMS reply has be fully parsed
317319 */
318- async kmsRequest ( request : MongoCryptKMSRequest , timeoutContext ?: TimeoutContext ) : Promise < void > {
320+ async kmsRequest (
321+ request : MongoCryptKMSRequest ,
322+ options : { timeoutContext ?: TimeoutContext } & Abortable
323+ ) : Promise < void > {
319324 const parsedUrl = request . endpoint . split ( ':' ) ;
320325 const port = parsedUrl [ 1 ] != null ? Number . parseInt ( parsedUrl [ 1 ] , 10 ) : HTTPS_PORT ;
321- const socketOptions = autoSelectSocketOptions ( this . options . socketOptions || { } ) ;
322- const options : tls . ConnectionOptions & {
326+ const socketOptions : tls . ConnectionOptions & {
323327 host : string ;
324328 port : number ;
325329 autoSelectFamily ?: boolean ;
@@ -328,7 +332,7 @@ export class StateMachine {
328332 host : parsedUrl [ 0 ] ,
329333 servername : parsedUrl [ 0 ] ,
330334 port,
331- ...socketOptions
335+ ...autoSelectSocketOptions ( this . options . socketOptions || { } )
332336 } ;
333337 const message = request . message ;
334338 const buffer = new BufferPool ( ) ;
@@ -363,7 +367,7 @@ export class StateMachine {
363367 throw error ;
364368 }
365369 try {
366- await this . setTlsOptions ( providerTlsOptions , options ) ;
370+ await this . setTlsOptions ( providerTlsOptions , socketOptions ) ;
367371 } catch ( err ) {
368372 throw onerror ( err ) ;
369373 }
@@ -380,23 +384,25 @@ export class StateMachine {
380384 . once ( 'close' , ( ) => rejectOnNetSocketError ( onclose ( ) ) )
381385 . once ( 'connect' , ( ) => resolveOnNetSocketConnect ( ) ) ;
382386
387+ let abortListener ;
388+
383389 try {
384390 if ( this . options . proxyOptions && this . options . proxyOptions . proxyHost ) {
385391 const netSocketOptions = {
392+ ...socketOptions ,
386393 host : this . options . proxyOptions . proxyHost ,
387- port : this . options . proxyOptions . proxyPort || 1080 ,
388- ...socketOptions
394+ port : this . options . proxyOptions . proxyPort || 1080
389395 } ;
390396 netSocket . connect ( netSocketOptions ) ;
391397 await willConnect ;
392398
393399 try {
394400 socks ??= loadSocks ( ) ;
395- options . socket = (
401+ socketOptions . socket = (
396402 await socks . SocksClient . createConnection ( {
397403 existing_socket : netSocket ,
398404 command : 'connect' ,
399- destination : { host : options . host , port : options . port } ,
405+ destination : { host : socketOptions . host , port : socketOptions . port } ,
400406 proxy : {
401407 // host and port are ignored because we pass existing_socket
402408 host : 'iLoveJavaScript' ,
@@ -412,7 +418,7 @@ export class StateMachine {
412418 }
413419 }
414420
415- socket = tls . connect ( options , ( ) => {
421+ socket = tls . connect ( socketOptions , ( ) => {
416422 socket . write ( message ) ;
417423 } ) ;
418424
@@ -422,6 +428,11 @@ export class StateMachine {
422428 resolve
423429 } = promiseWithResolvers < void > ( ) ;
424430
431+ abortListener = addAbortListener ( options . signal , error => {
432+ destroySockets ( ) ;
433+ rejectOnTlsSocketError ( error ) ;
434+ } ) ;
435+
425436 socket
426437 . once ( 'error' , err => rejectOnTlsSocketError ( onerror ( err ) ) )
427438 . once ( 'close' , ( ) => rejectOnTlsSocketError ( onclose ( ) ) )
@@ -436,8 +447,11 @@ export class StateMachine {
436447 resolve ( ) ;
437448 }
438449 } ) ;
439- await ( timeoutContext ?. csotEnabled ( )
440- ? Promise . all ( [ willResolveKmsRequest , Timeout . expires ( timeoutContext ?. remainingTimeMS ) ] )
450+ await ( options . timeoutContext ?. csotEnabled ( )
451+ ? Promise . all ( [
452+ willResolveKmsRequest ,
453+ Timeout . expires ( options . timeoutContext ?. remainingTimeMS )
454+ ] )
441455 : willResolveKmsRequest ) ;
442456 } catch ( error ) {
443457 if ( error instanceof TimeoutError )
@@ -446,16 +460,17 @@ export class StateMachine {
446460 } finally {
447461 // There's no need for any more activity on this socket at this point.
448462 destroySockets ( ) ;
463+ abortListener ?. [ kDispose ] ( ) ;
449464 }
450465 }
451466
452- * requests ( context : MongoCryptContext , timeoutContext ?: TimeoutContext ) {
467+ * requests ( context : MongoCryptContext , options : { timeoutContext ?: TimeoutContext } & Abortable ) {
453468 for (
454469 let request = context . nextKMSRequest ( ) ;
455470 request != null ;
456471 request = context . nextKMSRequest ( )
457472 ) {
458- yield this . kmsRequest ( request , timeoutContext ) ;
473+ yield this . kmsRequest ( request , options ) ;
459474 }
460475 }
461476
@@ -516,14 +531,16 @@ export class StateMachine {
516531 client : MongoClient ,
517532 ns : string ,
518533 filter : Document ,
519- timeoutContext ?: TimeoutContext
534+ options : { timeoutContext ?: TimeoutContext } & Abortable
520535 ) : Promise < Uint8Array | null > {
521536 const { db } = MongoDBCollectionNamespace . fromString ( ns ) ;
522537
523538 const cursor = client . db ( db ) . listCollections ( filter , {
524539 promoteLongs : false ,
525540 promoteValues : false ,
526- timeoutContext : timeoutContext && new CursorTimeoutContext ( timeoutContext , Symbol ( ) )
541+ timeoutContext :
542+ options . timeoutContext && new CursorTimeoutContext ( options . timeoutContext , Symbol ( ) ) ,
543+ signal : options . signal
527544 } ) ;
528545
529546 // There is always exactly zero or one matching documents, so this should always exhaust the cursor
@@ -547,17 +564,30 @@ export class StateMachine {
547564 client : MongoClient ,
548565 ns : string ,
549566 command : Uint8Array ,
550- timeoutContext ?: TimeoutContext
567+ options : { timeoutContext ?: TimeoutContext } & Abortable
551568 ) : Promise < Uint8Array > {
552569 const { db } = MongoDBCollectionNamespace . fromString ( ns ) ;
553570 const bsonOptions = { promoteLongs : false , promoteValues : false } ;
554571 const rawCommand = deserialize ( command , bsonOptions ) ;
555572
573+ const commandOptions : {
574+ timeoutMS ?: number ;
575+ signal ?: AbortSignal ;
576+ } = {
577+ timeoutMS : undefined ,
578+ signal : undefined
579+ } ;
580+
581+ if ( options . timeoutContext ?. csotEnabled ( ) ) {
582+ commandOptions . timeoutMS = options . timeoutContext . remainingTimeMS ;
583+ }
584+ if ( options . signal ) {
585+ commandOptions . signal = options . signal ;
586+ }
587+
556588 const response = await client . db ( db ) . command ( rawCommand , {
557589 ...bsonOptions ,
558- ...( timeoutContext ?. csotEnabled ( )
559- ? { timeoutMS : timeoutContext ?. remainingTimeMS }
560- : undefined )
590+ ...commandOptions
561591 } ) ;
562592
563593 return serialize ( response , this . bsonOptions ) ;
@@ -575,17 +605,30 @@ export class StateMachine {
575605 client : MongoClient ,
576606 keyVaultNamespace : string ,
577607 filter : Uint8Array ,
578- timeoutContext ?: TimeoutContext
608+ options : { timeoutContext ?: TimeoutContext } & Abortable
579609 ) : Promise < Array < DataKey > > {
580610 const { db : dbName , collection : collectionName } =
581611 MongoDBCollectionNamespace . fromString ( keyVaultNamespace ) ;
582612
613+ const commandOptions : {
614+ timeoutContext ?: CursorTimeoutContext ;
615+ signal ?: AbortSignal ;
616+ } = {
617+ timeoutContext : undefined ,
618+ signal : undefined
619+ } ;
620+
621+ if ( options . timeoutContext != null ) {
622+ commandOptions . timeoutContext = new CursorTimeoutContext ( options . timeoutContext , Symbol ( ) ) ;
623+ }
624+ if ( options . signal != null ) {
625+ commandOptions . signal = options . signal ;
626+ }
627+
583628 return client
584629 . db ( dbName )
585630 . collection < DataKey > ( collectionName , { readConcern : { level : 'majority' } } )
586- . find ( deserialize ( filter ) , {
587- timeoutContext : timeoutContext && new CursorTimeoutContext ( timeoutContext , Symbol ( ) )
588- } )
631+ . find ( deserialize ( filter ) , commandOptions )
589632 . toArray ( ) ;
590633 }
591634}
0 commit comments