@@ -225,7 +225,7 @@ export class CanvasStateApiModule extends CanvasModuleBase {
225225 * controller.abort();
226226 * ```
227227 */
228- runGraphAndReturnImageOutput = async ( arg : {
228+ runGraphAndReturnImageOutput = ( arg : {
229229 graph : Graph ;
230230 outputNodeId : string ;
231231 destination ?: string ;
@@ -268,27 +268,12 @@ export class CanvasStateApiModule extends CanvasModuleBase {
268268 }
269269 } ;
270270
271- /**
272- * First, enqueue the graph - we need the `batch_id` to cancel the graph. But to get the `batch_id`, we need to
273- * `await` the request. You might be tempted to `await` the request inside the result promise, but we should not
274- * `await` inside a promise executor.
275- *
276- * See: https://eslint.org/docs/latest/rules/no-async-promise-executor
277- */
278- const enqueueRequest = this . store . dispatch (
279- queueApi . endpoints . enqueueBatch . initiate ( batch , {
280- // Use the same cache key for all enqueueBatch requests, so that all consumers of this query get the same status
281- // updates.
282- fixedCacheKey : 'enqueueBatch' ,
283- // We do not need RTK to track this request in the store
284- track : false ,
285- } )
286- ) ;
287-
288- // The `batch_id` should _always_ be present - the OpenAPI schema from which the types are generated is incorrect.
289- // TODO(psyche): Fix the OpenAPI schema.
290- const { batch_id } = ( await enqueueRequest . unwrap ( ) ) . batch ;
291- assert ( batch_id , 'Enqueue result is missing batch_id' ) ;
271+ // There's a bit of a catch-22 here: we need to set the cancelGraph callback before we enqueue the graph, but we
272+ // can't set it until we have the batch_id from the enqueue request. So we'll set a dummy function here and update
273+ // it later.
274+ let cancelGraph : ( ) => void = ( ) => {
275+ this . log . warn ( 'cancelGraph called before cancelGraph is set' ) ;
276+ } ;
292277
293278 const resultPromise = new Promise < ImageDTO > ( ( resolve , reject ) => {
294279 const invocationCompleteHandler = async ( event : S [ 'InvocationCompleteEvent' ] ) => {
@@ -357,6 +342,36 @@ export class CanvasStateApiModule extends CanvasModuleBase {
357342 }
358343 } ;
359344
345+ // We are ready to enqueue the graph
346+ const enqueueRequest = this . store . dispatch (
347+ queueApi . endpoints . enqueueBatch . initiate ( batch , {
348+ // Use the same cache key for all enqueueBatch requests, so that all consumers of this query get the same status
349+ // updates.
350+ fixedCacheKey : 'enqueueBatch' ,
351+ // We do not need RTK to track this request in the store
352+ track : false ,
353+ } )
354+ ) ;
355+
356+ // Enqueue the graph and get the batch_id, updating the cancel graph callack. We need to do this in a .then() block
357+ // instead of awaiting the promise to avoid await-ing in a promise executor. Also need to catch any errors.
358+ enqueueRequest
359+ . unwrap ( )
360+ . then ( ( data ) => {
361+ // The `batch_id` should _always_ be present - the OpenAPI schema from which the types are generated is incorrect.
362+ // TODO(psyche): Fix the OpenAPI schema.
363+ const batch_id = data . batch . batch_id ;
364+ assert ( batch_id , 'Enqueue result is missing batch_id' ) ;
365+ cancelGraph = ( ) => {
366+ this . store . dispatch (
367+ queueApi . endpoints . cancelByBatchIds . initiate ( { batch_ids : [ batch_id ] } , { track : false } )
368+ ) ;
369+ } ;
370+ } )
371+ . catch ( ( error ) => {
372+ reject ( error ) ;
373+ } ) ;
374+
360375 this . manager . socket . on ( 'invocation_complete' , invocationCompleteHandler ) ;
361376 this . manager . socket . on ( 'queue_item_status_changed' , queueItemStatusChangedHandler ) ;
362377
@@ -365,10 +380,6 @@ export class CanvasStateApiModule extends CanvasModuleBase {
365380 this . manager . socket . off ( 'queue_item_status_changed' , queueItemStatusChangedHandler ) ;
366381 } ;
367382
368- const cancelGraph = ( ) => {
369- this . store . dispatch ( queueApi . endpoints . cancelByBatchIds . initiate ( { batch_ids : [ batch_id ] } , { track : false } ) ) ;
370- } ;
371-
372383 if ( timeout ) {
373384 timeoutId = window . setTimeout ( ( ) => {
374385 this . log . trace ( 'Graph canceled by timeout' ) ;
0 commit comments