Skip to content

Commit b5ef835

Browse files
kungfoomanxenova
andauthored
Fix NaNs when using ORT proxy (#404)
* Move tensor clone for Worker ownership NaN issue * Update src/models.js - Use conditional operator Co-authored-by: Joshua Lochner <[email protected]> * Update src/models.js - Object.create(null) Co-authored-by: Joshua Lochner <[email protected]> * tensor.js: remove "Object" type to fix types (since ONNX exports correct type now) * models.js / validateInputs(): Remove promise/await because it is not needed Use "tensor instanceof Tensor" check because otherwise validateInputs() thinks it has an input even if it doesn't * Fix JSDoc * Update JSDoc --------- Co-authored-by: Joshua Lochner <[email protected]>
1 parent ac0096e commit b5ef835

File tree

3 files changed

+24
-13
lines changed

3 files changed

+24
-13
lines changed

src/backends/onnx.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import * as ONNX_NODE from 'onnxruntime-node';
2222
import * as ONNX_WEB from 'onnxruntime-web';
2323

24-
/** @type {module} The ONNX runtime module. */
24+
/** @type {import('onnxruntime-web')} The ONNX runtime module. */
2525
export let ONNX;
2626

2727
export const executionProviders = [

src/models.js

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,9 @@ import {
8686

8787
import { executionProviders, ONNX } from './backends/onnx.js';
8888
import { medianFilter } from './transformers.js';
89-
const { InferenceSession, Tensor: ONNXTensor } = ONNX;
89+
const { InferenceSession, Tensor: ONNXTensor, env } = ONNX;
90+
91+
/** @typedef {import('onnxruntime-web').InferenceSession} InferenceSession */
9092

9193
//////////////////////////////////////////////////
9294
// Model types: used internally
@@ -146,21 +148,31 @@ async function constructSession(pretrained_model_name_or_path, fileName, options
146148
/**
147149
* Validate model inputs
148150
* @param {InferenceSession} session The InferenceSession object that will be run.
149-
* @param {Object} inputs The inputs to check.
150-
* @returns {Promise<Object>} A Promise that resolves to the checked inputs.
151+
* @param {Record<string, Tensor>} inputs The inputs to check.
152+
* @returns {Record<string, Tensor>} The checked inputs.
151153
* @throws {Error} If any inputs are missing.
152154
* @private
153155
*/
154-
async function validateInputs(session, inputs) {
155-
// NOTE: Only create a shallow copy
156-
const checkedInputs = {};
156+
function validateInputs(session, inputs) {
157+
/**
158+
* NOTE: Create either a shallow or deep copy based on `onnx.wasm.proxy`
159+
* @type {Record<string, Tensor>}
160+
*/
161+
const checkedInputs = Object.create(null);
157162
const missingInputs = [];
158-
for (let inputName of session.inputNames) {
159-
if (inputs[inputName] === undefined) {
163+
for (const inputName of session.inputNames) {
164+
const tensor = inputs[inputName];
165+
// Rare case where one of the model's input names corresponds to a built-in
166+
// object name (e.g., toString), which would cause a simple (!tensor) check to fail,
167+
// because it's not undefined but a function.
168+
if (!(tensor instanceof Tensor)) {
160169
missingInputs.push(inputName);
161-
} else {
162-
checkedInputs[inputName] = inputs[inputName];
170+
continue;
163171
}
172+
// NOTE: When `env.wasm.proxy is true` the tensor is moved across the Worker
173+
// boundary, transferring ownership to the worker and invalidating the tensor.
174+
// So, in this case, we simply sacrifice a clone for it.
175+
checkedInputs[inputName] = env.wasm.proxy ? tensor.clone() : tensor;
164176
}
165177
if (missingInputs.length > 0) {
166178
throw new Error(
@@ -191,7 +203,7 @@ async function validateInputs(session, inputs) {
191203
* @private
192204
*/
193205
async function sessionRun(session, inputs) {
194-
const checkedInputs = await validateInputs(session, inputs);
206+
const checkedInputs = validateInputs(session, inputs);
195207
try {
196208
let output = await session.run(checkedInputs);
197209
output = replaceTensors(output);

src/utils/tensor.js

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,6 @@ const DataTypeMap = new Map([
3434
* @typedef {import('./maths.js').AnyTypedArray | any[]} DataArray
3535
*/
3636

37-
/** @type {Object} */
3837
const ONNXTensor = ONNX.Tensor;
3938

4039
export class Tensor extends ONNXTensor {

0 commit comments

Comments
 (0)