@@ -237,6 +237,7 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
237237 const session_config = {
238238 dtype : selectedDtype ,
239239 kv_cache_dtype,
240+ device : selectedDevice ,
240241 }
241242
242243 // Construct the model file name
@@ -417,6 +418,10 @@ function validateInputs(session, inputs) {
417418 return checkedInputs ;
418419}
419420
421+ // Currently, Transformers.js doesn't support simultaneous execution of sessions in WASM/WebGPU (otherwise we get "Error: Session already started"). For this reason, in browser and web worker environments every `session.run` call is appended with `.then` to this single module-level promise, so a run starts only once the previously queued run has resolved.
422+ // NOTE(review): because runs are appended with `.then(run)` and the result is stored back here, a run that rejects leaves this promise rejected, and runs appended afterwards reject with the same error without being executed. Confirm whether the chain should be reset after a failure.
423+ let webInferenceChain = Promise . resolve ( ) ;
424+
420425/**
421426 * Executes an InferenceSession using the specified inputs.
422427 * NOTE: `inputs` must contain at least the input names of the model.
@@ -433,17 +438,28 @@ async function sessionRun(session, inputs) {
433438 try {
434439 // pass the original ort tensor
435440 const ortFeed = Object . fromEntries ( Object . entries ( checkedInputs ) . map ( ( [ k , v ] ) => [ k , v . ort_tensor ] ) ) ;
436- let output = await session . run ( ortFeed ) ;
437- output = replaceTensors ( output ) ;
438- return output ;
441+ const run = ( ) => session . run ( ortFeed ) ;
442+ const output = await ( ( apis . IS_BROWSER_ENV || apis . IS_WEBWORKER_ENV )
443+ ? ( webInferenceChain = webInferenceChain . then ( run ) )
444+ : run ( ) ) ;
445+ return replaceTensors ( output ) ;
439446 } catch ( e ) {
440447 // Error messages can be long (nested) and uninformative. For this reason,
441448 // we apply minor formatting to show the most important information
442449 const formatted = Object . fromEntries ( Object . entries ( checkedInputs )
443- . map ( ( [ k , { type , dims , data } ] ) => [ k , {
450+ . map ( ( [ k , tensor ] ) => {
444451 // Extract these properties from the underlying ORT tensor
445- type, dims, data,
446- } ] ) ) ;
452+ const unpacked = {
453+ type : tensor . type ,
454+ dims : tensor . dims ,
455+ location : tensor . location ,
456+ }
457+ if ( unpacked . location !== "gpu-buffer" ) {
458+ // Only return the data if it's not a GPU buffer
459+ unpacked . data = tensor . data ;
460+ }
461+ return [ k , unpacked ] ;
462+ } ) ) ;
447463
448464 // This usually occurs when the inputs are of the wrong type.
449465 console . error ( `An error occurred during model execution: "${ e } ".` ) ;
@@ -5223,7 +5239,7 @@ export class RTDetrV2ForObjectDetection extends RTDetrV2PreTrainedModel {
52235239 }
52245240}
52255241
5226- export class RTDetrV2ObjectDetectionOutput extends RTDetrObjectDetectionOutput { }
5242+ export class RTDetrV2ObjectDetectionOutput extends RTDetrObjectDetectionOutput { } // Declares no members of its own; everything is inherited from RTDetrObjectDetectionOutput.
52275243//////////////////////////////////////////////////
52285244
52295245//////////////////////////////////////////////////
@@ -5238,7 +5254,7 @@ export class RFDetrForObjectDetection extends RFDetrPreTrainedModel {
52385254 }
52395255}
52405256
5241- export class RFDetrObjectDetectionOutput extends RTDetrObjectDetectionOutput { }
5257+ export class RFDetrObjectDetectionOutput extends RTDetrObjectDetectionOutput { } // Declares no members of its own; everything is inherited from RTDetrObjectDetectionOutput.
52425258//////////////////////////////////////////////////
52435259
52445260//////////////////////////////////////////////////
0 commit comments