
Commit 5752f19

Introduce wasm/webgpu inference chain to prevent "Session already started" errors
1 parent 05924f9 commit 5752f19

File tree

1 file changed: +24 -8 lines changed


src/models.js

Lines changed: 24 additions & 8 deletions
@@ -237,6 +237,7 @@ async function getSession(pretrained_model_name_or_path, fileName, options) {
     const session_config = {
         dtype: selectedDtype,
         kv_cache_dtype,
+        device: selectedDevice,
     }
 
     // Construct the model file name
@@ -417,6 +418,10 @@ function validateInputs(session, inputs) {
     return checkedInputs;
 }
 
+// Currently, Transformers.js doesn't support simultaneous execution of sessions in WASM/WebGPU.
+// For this reason, we need to chain the inference calls (otherwise we get "Error: Session already started").
+let webInferenceChain = Promise.resolve();
+
 /**
  * Executes an InferenceSession using the specified inputs.
  * NOTE: `inputs` must contain at least the input names of the model.
@@ -433,17 +438,28 @@ async function sessionRun(session, inputs) {
     try {
         // pass the original ort tensor
         const ortFeed = Object.fromEntries(Object.entries(checkedInputs).map(([k, v]) => [k, v.ort_tensor]));
-        let output = await session.run(ortFeed);
-        output = replaceTensors(output);
-        return output;
+        const run = () => session.run(ortFeed);
+        const output = await ((apis.IS_BROWSER_ENV || apis.IS_WEBWORKER_ENV)
+            ? (webInferenceChain = webInferenceChain.then(run))
+            : run());
+        return replaceTensors(output);
     } catch (e) {
         // Error messages can be long (nested) and uninformative. For this reason,
         // we apply minor formatting to show the most important information
         const formatted = Object.fromEntries(Object.entries(checkedInputs)
-            .map(([k, { type, dims, data }]) => [k, {
+            .map(([k, tensor]) => {
                 // Extract these properties from the underlying ORT tensor
-                type, dims, data,
-            }]));
+                const unpacked = {
+                    type: tensor.type,
+                    dims: tensor.dims,
+                    location: tensor.location,
+                }
+                if (unpacked.location !== "gpu-buffer") {
+                    // Only return the data if it's not a GPU buffer
+                    unpacked.data = tensor.data;
+                }
+                return [k, unpacked];
+            }));
 
         // This usually occurs when the inputs are of the wrong type.
         console.error(`An error occurred during model execution: "${e}".`);
@@ -5207,7 +5223,7 @@ export class RTDetrV2ForObjectDetection extends RTDetrV2PreTrainedModel {
     }
 }
 
-export class RTDetrV2ObjectDetectionOutput extends RTDetrObjectDetectionOutput {}
+export class RTDetrV2ObjectDetectionOutput extends RTDetrObjectDetectionOutput { }
 //////////////////////////////////////////////////
 
 //////////////////////////////////////////////////
@@ -5222,7 +5238,7 @@ export class RFDetrForObjectDetection extends RFDetrPreTrainedModel {
     }
 }
 
-export class RFDetrObjectDetectionOutput extends RTDetrObjectDetectionOutput {}
+export class RFDetrObjectDetectionOutput extends RTDetrObjectDetectionOutput { }
 //////////////////////////////////////////////////
 
 //////////////////////////////////////////////////
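
For readers skimming the diff, here is a minimal standalone sketch of the chaining pattern the commit introduces: each call appends its run onto a shared promise, so runs on the WASM/WebGPU backends execute strictly one after another instead of overlapping. The names inferenceChain, runChained, and fakeRun below are illustrative only and are not part of Transformers.js.

// Standalone sketch of the serialization pattern (illustration only; not code
// from Transformers.js). Each call re-assigns the shared promise, so the next
// caller waits for the previous run to settle before starting its own.
let inferenceChain = Promise.resolve();

function runChained(run) {
    inferenceChain = inferenceChain.then(run);
    return inferenceChain;
}

// Hypothetical stand-in for session.run(): resolves after a short delay.
const fakeRun = (label, ms) => () =>
    new Promise((resolve) => setTimeout(() => {
        console.log(`${label} finished`);
        resolve(label);
    }, ms));

// Even though both calls are issued at the same time, "first finished" is
// always logged before "second finished", because the second run only starts
// once the first one has completed.
Promise.all([
    runChained(fakeRun("first", 50)),
    runChained(fakeRun("second", 10)),
]).then((results) => console.log(results)); // ["first", "second"]

Chaining with .then (rather than a separate queue or lock) keeps the change small: the chain variable is module-level, and non-browser environments bypass it entirely by calling run() directly, as the diff shows.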
