Skip to content

Commit 1088cc5

Browse files
authored
Support simultaneous tensor op execution (#1162)
1 parent 142f6e1 commit 1088cc5

File tree

1 file changed

+9
-1
lines changed

1 file changed

+9
-1
lines changed

src/ops/registry.js

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import { createInferenceSession, isONNXProxy } from "../backends/onnx.js";
22
import { Tensor } from "../utils/tensor.js";
3+
import { apis } from "../env.js";
34

5+
const IS_WEB_ENV = apis.IS_BROWSER_ENV || apis.IS_WEBWORKER_ENV;
46
/**
57
* Asynchronously creates a wrapper function for running an ONNX inference session.
68
*
@@ -16,10 +18,16 @@ const wrap = async (session_bytes, session_options, names) => {
1618
const session = await createInferenceSession(
1719
new Uint8Array(session_bytes), session_options,
1820
);
21+
22+
/** @type {Promise<any>} */
23+
let chain = Promise.resolve();
24+
1925
return /** @type {any} */(async (/** @type {Record<string, Tensor>} */ inputs) => {
2026
const proxied = isONNXProxy();
2127
const ortFeed = Object.fromEntries(Object.entries(inputs).map(([k, v]) => [k, (proxied ? v.clone() : v).ort_tensor]));
22-
const outputs = await session.run(ortFeed);
28+
29+
// When running in-browser via WASM, we need to chain calls to session.run to avoid "Error: Session already started"
30+
const outputs = await (chain = IS_WEB_ENV ? chain.then(() => session.run(ortFeed)) : session.run(ortFeed));
2331

2432
if (Array.isArray(names)) {
2533
return names.map((n) => new Tensor(outputs[n]));

0 commit comments

Comments
 (0)