Skip to content

Commit eceae8b

Browse files
authored
[WebNN/WebGPU JS] Fix shared Module methods overriding each other (microsoft#23998)
Renamed all conflicting WebNN methods from `jsep*` to `webnn*`. WebNN doesn't need flush(), so it doesn't need to set `jsepBackend`. This PR addresses issue microsoft/webnn-developer-preview#78
1 parent d98046b commit eceae8b

File tree

8 files changed

+89
-71
lines changed

8 files changed

+89
-71
lines changed

js/web/lib/wasm/wasm-core-impl.ts

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -309,12 +309,12 @@ export const createSession = async (
309309
if (context) {
310310
wasm.currentContext = context as MLContext;
311311
} else if (gpuDevice) {
312-
wasm.currentContext = await wasm.jsepCreateMLContext!(gpuDevice);
312+
wasm.currentContext = await wasm.webnnCreateMLContext!(gpuDevice);
313313
} else {
314-
wasm.currentContext = await wasm.jsepCreateMLContext!({ deviceType, powerPreference });
314+
wasm.currentContext = await wasm.webnnCreateMLContext!({ deviceType, powerPreference });
315315
}
316316
} else {
317-
wasm.currentContext = await wasm.jsepCreateMLContext!();
317+
wasm.currentContext = await wasm.webnnCreateMLContext!();
318318
}
319319
break;
320320
}
@@ -330,7 +330,7 @@ export const createSession = async (
330330

331331
// clear current MLContext after session creation
332332
if (wasm.currentContext) {
333-
wasm.jsepRegisterMLContext!(sessionHandle, wasm.currentContext);
333+
wasm.webnnRegisterMLContext!(sessionHandle, wasm.currentContext);
334334
wasm.currentContext = undefined;
335335
wasm.shouldTransferToMLTensor = true;
336336
}
@@ -454,6 +454,7 @@ export const releaseSession = (sessionId: number): void => {
454454
}
455455

456456
wasm.jsepOnReleaseSession?.(sessionId);
457+
wasm.webnnOnReleaseSession?.(sessionId);
457458
wasm.webgpuOnReleaseSession?.(sessionId);
458459

459460
inputNamesUTF8Encoded.forEach((buf) => wasm._OrtFree(buf));
@@ -520,7 +521,7 @@ export const prepareInputOutputTensor = async (
520521
const mlTensor = tensor[2].mlTensor as MLTensor;
521522
dataByteLength = calculateTensorSizeInBytes(tensorDataTypeStringToEnum(dataType), dims)!;
522523

523-
const registerMLTensor = wasm.jsepRegisterMLTensor;
524+
const registerMLTensor = wasm.webnnRegisterMLTensor;
524525
if (!registerMLTensor) {
525526
throw new Error('Tensor location "ml-tensor" is not supported without using WebNN.');
526527
}
@@ -540,7 +541,7 @@ export const prepareInputOutputTensor = async (
540541
wasm.setValue(rawData + i * ptrSize, allocWasmString(data[i], allocs), '*');
541542
}
542543
} else {
543-
const isGraphInput = wasm.jsepIsGraphInput;
544+
const isGraphInput = wasm.webnnIsGraphInput;
544545
if (dataType !== 'string' && isGraphInput) {
545546
const tensorNameUTF8 = wasm._OrtGetInputName(sessionId, index);
546547
const tensorName = wasm.UTF8ToString(tensorNameUTF8);
@@ -549,8 +550,8 @@ export const prepareInputOutputTensor = async (
549550
const dataTypeEnum = tensorDataTypeStringToEnum(dataType);
550551
dataByteLength = calculateTensorSizeInBytes(dataTypeEnum, dims)!;
551552
actualLocation = 'ml-tensor';
552-
const createTemporaryTensor = wasm.jsepCreateTemporaryTensor;
553-
const uploadTensor = wasm.jsepUploadTensor;
553+
const createTemporaryTensor = wasm.webnnCreateTemporaryTensor;
554+
const uploadTensor = wasm.webnnUploadTensor;
554555
if (!createTemporaryTensor || !uploadTensor) {
555556
throw new Error('Tensor location "ml-tensor" is not supported without using WebNN.');
556557
}
@@ -722,6 +723,7 @@ export const run = async (
722723
}
723724

724725
wasm.jsepOnRunStart?.(sessionHandle);
726+
wasm.webnnOnRunStart?.(sessionHandle);
725727

726728
let errorCode: number;
727729
if (!BUILD_DEFS.DISABLE_JSEP && ioBindingState) {
@@ -862,8 +864,8 @@ export const run = async (
862864
]);
863865
}
864866
} else if (preferredLocation === 'ml-tensor' && size > 0) {
865-
const ensureTensor = wasm.jsepEnsureTensor;
866-
const isInt64Supported = wasm.jsepIsInt64Supported;
867+
const ensureTensor = wasm.webnnEnsureTensor;
868+
const isInt64Supported = wasm.webnnIsInt64Supported;
867869
if (!ensureTensor || !isInt64Supported) {
868870
throw new Error('preferredLocation "ml-tensor" is not supported without using WebNN.');
869871
}
@@ -890,9 +892,9 @@ export const run = async (
890892
dims,
891893
{
892894
mlTensor,
893-
download: wasm.jsepCreateMLTensorDownloader!(dataOffset, type),
895+
download: wasm.webnnCreateMLTensorDownloader!(dataOffset, type),
894896
dispose: () => {
895-
wasm.jsepReleaseTensorId!(dataOffset);
897+
wasm.webnnReleaseTensorId!(dataOffset);
896898
wasm._OrtReleaseTensor(tensor);
897899
},
898900
},
@@ -915,7 +917,7 @@ export const run = async (
915917
if (!keepOutputTensor) {
916918
wasm._OrtReleaseTensor(tensor);
917919
}
918-
wasm.jsepOnRunEnd?.(sessionHandle);
920+
wasm.webnnOnRunEnd?.(sessionHandle);
919921
}
920922
}
921923

js/web/lib/wasm/wasm-types.ts

Lines changed: 29 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -156,31 +156,45 @@ export declare namespace JSEP {
156156
*/
157157
shouldTransferToMLTensor: boolean;
158158

159+
/**
160+
* [exported from pre-jsep.js] Called when InferenceSession.run started. This function will be called before
161+
* _OrtRun[WithBinding]() is called.
162+
* @param sessionId - specify the session ID.
163+
*/
164+
webnnOnRunStart: (sessionId: number) => void;
165+
/**
166+
* [exported from pre-jsep.js] Release a session. This function will be called before _OrtReleaseSession() is
167+
* called.
168+
* @param sessionId - specify the session ID.
169+
* @returns
170+
*/
171+
webnnOnReleaseSession: (sessionId: number) => void;
172+
159173
/**
160174
* [exported from pre-jsep.js] Called when InferenceSession.run finished. This function will be called after
161175
* _OrtRun[WithBinding]() is called.
162176
* @param sessionId - specify the session ID.
163177
*/
164-
jsepOnRunEnd: (sessionId: number) => void;
178+
webnnOnRunEnd: (sessionId: number) => void;
165179

166180
/**
167181
* [exported from pre-jsep.js] Register MLContext for a session.
168182
* @param sessionId - specify the session ID.
169183
* @param context - specify the MLContext.
170184
* @returns
171185
*/
172-
jsepRegisterMLContext: (sessionId: number, context: MLContext) => void;
186+
webnnRegisterMLContext: (sessionId: number, context: MLContext) => void;
173187
/**
174188
* [exported from pre-jsep.js] Reserve a MLTensor ID attached to the current session.
175189
* @returns the MLTensor ID.
176190
*/
177-
jsepReserveTensorId: () => number;
191+
webnnReserveTensorId: () => number;
178192
/**
179193
* [exported from pre-jsep.js] Release an MLTensor ID from use and destroys underlying MLTensor if no longer in use.
180194
* @param tensorId - specify the MLTensor ID.
181195
* @returns
182196
*/
183-
jsepReleaseTensorId: (tensorId: number) => void;
197+
webnnReleaseTensorId: (tensorId: number) => void;
184198
/**
185199
* [exported from pre-jsep.js] Ensure that an MLTensor of a given type and shape exists for a MLTensor ID.
186200
* @param sessionId - specify the session ID or current active session ID if undefined.
@@ -190,7 +204,7 @@ export declare namespace JSEP {
190204
* @param copyOld - specify whether to copy the old tensor if a new tensor was created.
191205
* @returns the MLTensor associated with the tensor ID.
192206
*/
193-
jsepEnsureTensor: (
207+
webnnEnsureTensor: (
194208
sessionId: number | undefined,
195209
tensorId: number,
196210
dataType: DataType,
@@ -203,20 +217,20 @@ export declare namespace JSEP {
203217
* @param data - specify the data to upload. It can be a TensorProto::data_type or a WebNN MLOperandDataType.
204218
* @returns
205219
*/
206-
jsepUploadTensor: (tensorId: number, data: Uint8Array) => void;
220+
webnnUploadTensor: (tensorId: number, data: Uint8Array) => void;
207221
/**
208222
* [exported from pre-jsep.js] Download data from an MLTensor.
209223
* @param tensorId - specify the MLTensor ID.
210224
* @returns the downloaded data.
211225
*/
212-
jsepDownloadTensor: (tensorId: number, dstBuffer: ArrayBufferView | ArrayBuffer) => Promise<undefined>;
226+
webnnDownloadTensor: (tensorId: number, dstBuffer: ArrayBufferView | ArrayBuffer) => Promise<undefined>;
213227
/**
214228
* [exported from pre-jsep.js] Creates a downloader function to download data from an MLTensor.
215229
* @param tensorId - specify the MLTensor ID.
216230
* @param type - specify the data type.
217231
* @returns the downloader function.
218232
*/
219-
jsepCreateMLTensorDownloader: (
233+
webnnCreateMLTensorDownloader: (
220234
tensorId: number,
221235
type: Tensor.MLTensorDataTypes,
222236
) => () => Promise<Tensor.DataTypeMap[Tensor.MLTensorDataTypes]>;
@@ -228,7 +242,7 @@ export declare namespace JSEP {
228242
* @param dimensions - specify the dimensions.
229243
* @returns the MLTensor ID for the external MLTensor.
230244
*/
231-
jsepRegisterMLTensor: (
245+
webnnRegisterMLTensor: (
232246
sessionId: number,
233247
tensor: MLTensor,
234248
onnxDataType: DataType,
@@ -240,7 +254,7 @@ export declare namespace JSEP {
240254
* @param optionsOrGpuDevice - specify the options or GPUDevice.
241255
* @returns
242256
*/
243-
jsepCreateMLContext(optionsOrGpuDevice?: MLContextOptions | GPUDevice): Promise<MLContext>;
257+
webnnCreateMLContext(optionsOrGpuDevice?: MLContextOptions | GPUDevice): Promise<MLContext>;
244258

245259
/**
246260
* [exported from pre-jsep.js] Register a WebNN Constant operand from external data.
@@ -252,7 +266,7 @@ export declare namespace JSEP {
252266
* @param shouldConvertInt64ToInt32 - specify whether to convert int64 to int32.
253267
* @returns the WebNN Constant operand for the specified external data.
254268
*/
255-
jsepRegisterMLConstant(
269+
webnnRegisterMLConstant(
256270
externalFilePath: string,
257271
dataOffset: number,
258272
dataLength: number,
@@ -265,28 +279,28 @@ export declare namespace JSEP {
265279
* [exported from pre-jsep.js] Register a WebNN graph input.
266280
* @param inputName - specify the input name.
267281
*/
268-
jsepRegisterGraphInput: (inputName: string) => void;
282+
webnnRegisterGraphInput: (inputName: string) => void;
269283
/**
270284
* [exported from pre-jsep.js] Check if a graph input is a WebNN graph input.
271285
* @param sessionId - specify the session ID.
272286
* @param inputName - specify the input name.
273287
* @returns whether the input is a WebNN graph input.
274288
*/
275-
jsepIsGraphInput: (sessionId: number, inputName: string) => boolean;
289+
webnnIsGraphInput: (sessionId: number, inputName: string) => boolean;
276290
/**
277291
* [exported from pre-jsep.js] Create a temporary MLTensor for a session.
278292
* @param sessionId - specify the session ID.
279293
* @param dataType - specify the data type.
280294
* @param shape - specify the shape.
281295
* @returns the MLTensor ID for the temporary MLTensor.
282296
*/
283-
jsepCreateTemporaryTensor: (sessionId: number, dataType: DataType, shape: readonly number[]) => Promise<number>;
297+
webnnCreateTemporaryTensor: (sessionId: number, dataType: DataType, shape: readonly number[]) => Promise<number>;
284298
/**
285299
* [exported from pre-jsep.js] Check if a session's associated WebNN Context supports int64.
286300
* @param sessionId - specify the session ID.
287301
* @returns whether the WebNN Context supports int64.
288302
*/
289-
jsepIsInt64Supported: (sessionId: number) => boolean;
303+
webnnIsInt64Supported: (sessionId: number) => boolean;
290304
}
291305
}
292306

onnxruntime/core/providers/webnn/allocator.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ void* WebNNTensorAllocator::Alloc(size_t size) {
1616
// We don't need to transfer the tensor to an MLTensor, so we don't need to allocate an MLTensor id.
1717
return nullptr;
1818
}
19-
void* p = EM_ASM_PTR({ return Module.jsepReserveTensorId(); });
19+
void* p = EM_ASM_PTR({ return Module.webnnReserveTensorId(); });
2020
allocations_[p] = size;
2121
stats_.num_allocs++;
2222
stats_.bytes_in_use += SafeInt<int64_t>(size);
@@ -27,7 +27,7 @@ void WebNNTensorAllocator::Free(void* p) {
2727
if (p == nullptr) {
2828
return;
2929
}
30-
EM_ASM({ Module.jsepReleaseTensorId($0); }, p);
30+
EM_ASM({ Module.webnnReleaseTensorId($0); }, p);
3131
size_t size = allocations_[p];
3232
stats_.bytes_in_use -= size;
3333
allocations_.erase(p);

onnxruntime/core/providers/webnn/builders/model.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -157,15 +157,15 @@ onnxruntime::common::Status Model::Compute(const InlinedHashMap<std::string, Onn
157157

158158
onnxruntime::common::Status Model::Dispatch(const InlinedHashMap<std::string, OnnxTensorData>& inputs,
159159
const InlinedHashMap<std::string, OnnxTensorData>& outputs) {
160-
auto jsepEnsureTensor = emscripten::val::module_property("jsepEnsureTensor");
160+
auto webnnEnsureTensor = emscripten::val::module_property("webnnEnsureTensor");
161161
auto promises = emscripten::val::array();
162162
for (const auto& [_, tensor] : inputs) {
163163
emscripten::val shape = emscripten::val::array();
164164
for (const auto& dim : tensor.tensor_info.shape) {
165165
uint32_t dim_val = SafeInt<uint32_t>(dim);
166166
shape.call<void>("push", dim_val);
167167
}
168-
auto ml_tensor = jsepEnsureTensor(emscripten::val::undefined(), reinterpret_cast<intptr_t>(tensor.buffer), tensor.tensor_info.data_type, shape, true);
168+
auto ml_tensor = webnnEnsureTensor(emscripten::val::undefined(), reinterpret_cast<intptr_t>(tensor.buffer), tensor.tensor_info.data_type, shape, true);
169169
promises.call<void>("push", ml_tensor);
170170
}
171171
for (const auto& [_, tensor] : outputs) {
@@ -174,7 +174,7 @@ onnxruntime::common::Status Model::Dispatch(const InlinedHashMap<std::string, On
174174
uint32_t dim_val = SafeInt<uint32_t>(dim);
175175
shape.call<void>("push", dim_val);
176176
}
177-
auto ml_tensor = jsepEnsureTensor(emscripten::val::undefined(), reinterpret_cast<intptr_t>(tensor.buffer), tensor.tensor_info.data_type, shape, false);
177+
auto ml_tensor = webnnEnsureTensor(emscripten::val::undefined(), reinterpret_cast<intptr_t>(tensor.buffer), tensor.tensor_info.data_type, shape, false);
178178
promises.call<void>("push", ml_tensor);
179179
}
180180
auto ml_tensors = emscripten::val::global("Promise").call<emscripten::val>("all", promises).await();

onnxruntime/core/providers/webnn/builders/model_builder.cc

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -140,13 +140,13 @@ Status ModelBuilder::RegisterInitializers() {
140140
ORT_RETURN_IF_ERROR(utils::GetExternalDataInfo(
141141
tensor, graph_viewer_.ModelPath(), external_file_path, data_offset, tensor_byte_size));
142142

143-
auto jsepRegisterMLConstant = emscripten::val::module_property("jsepRegisterMLConstant");
144-
operand = jsepRegisterMLConstant(emscripten::val(external_file_path),
145-
static_cast<int32_t>(data_offset),
146-
static_cast<int32_t>(tensor_byte_size),
147-
wnn_builder_,
148-
desc,
149-
should_convert_int64_to_int32);
143+
auto webnnRegisterMLConstant = emscripten::val::module_property("webnnRegisterMLConstant");
144+
operand = webnnRegisterMLConstant(emscripten::val(external_file_path),
145+
static_cast<int32_t>(data_offset),
146+
static_cast<int32_t>(tensor_byte_size),
147+
wnn_builder_,
148+
desc,
149+
should_convert_int64_to_int32);
150150
} else {
151151
if (tensor.has_raw_data()) {
152152
tensor_ptr = reinterpret_cast<std::byte*>(const_cast<char*>(tensor.raw_data().c_str()));
@@ -288,7 +288,7 @@ Status ModelBuilder::RegisterModelInputOutput(const NodeArg& node_arg, bool is_i
288288
desc.set("dataType", emscripten::val("int32"));
289289
}
290290
wnn_operands_.insert(std::make_pair(name, wnn_builder_.call<emscripten::val>("input", name, desc)));
291-
emscripten::val::module_property("jsepRegisterGraphInput")(name);
291+
emscripten::val::module_property("webnnRegisterGraphInput")(name);
292292
input_names_.push_back(name);
293293
} else {
294294
output_names_.push_back(name);

onnxruntime/core/providers/webnn/data_transfer.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,11 +29,11 @@ common::Status DataTransfer::CopyTensor(const Tensor& src, Tensor& dst) const {
2929
const auto& dst_device = dst.Location().device;
3030

3131
if (dst_device.Type() == OrtDevice::GPU) {
32-
EM_ASM({ Module.jsepUploadTensor($0, HEAPU8.subarray($1, $1 + $2)); }, dst_data, reinterpret_cast<intptr_t>(src_data), bytes);
32+
EM_ASM({ Module.webnnUploadTensor($0, HEAPU8.subarray($1, $1 + $2)); }, dst_data, reinterpret_cast<intptr_t>(src_data), bytes);
3333
} else {
34-
auto jsepDownloadTensor = emscripten::val::module_property("jsepDownloadTensor");
34+
auto webnnDownloadTensor = emscripten::val::module_property("webnnDownloadTensor");
3535
auto subarray = emscripten::typed_memory_view(bytes, static_cast<char*>(dst_data));
36-
jsepDownloadTensor(reinterpret_cast<intptr_t>(src_data), subarray).await();
36+
webnnDownloadTensor(reinterpret_cast<intptr_t>(src_data), subarray).await();
3737
}
3838
}
3939

onnxruntime/core/providers/webnn/webnn_execution_provider.cc

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,7 @@ class WebNNMemcpy : public OpKernel {
284284
explicit WebNNMemcpy(const OpKernelInfo& info) : OpKernel(info) {}
285285

286286
Status Compute(OpKernelContext* context) const override {
287-
auto jsepEnsureTensor = emscripten::val::module_property("jsepEnsureTensor");
287+
auto webnnEnsureTensor = emscripten::val::module_property("webnnEnsureTensor");
288288
const auto* X = context->Input<Tensor>(0);
289289
ORT_ENFORCE(X != nullptr, "Memcpy: input tensor is null");
290290
auto* Y = context->Output(0, X->Shape());
@@ -294,10 +294,10 @@ class WebNNMemcpy : public OpKernel {
294294
shape.call<void>("push", SafeInt<uint32_t>(dim).Ref());
295295
}
296296

297-
jsepEnsureTensor(emscripten::val::undefined(),
298-
reinterpret_cast<intptr_t>(Y->MutableDataRaw()),
299-
Y->GetElementType(),
300-
shape, false)
297+
webnnEnsureTensor(emscripten::val::undefined(),
298+
reinterpret_cast<intptr_t>(Y->MutableDataRaw()),
299+
Y->GetElementType(),
300+
shape, false)
301301
.await();
302302

303303
const auto* data_transfer = Info().GetDataTransferManager().GetDataTransfer(X->Location().device, Y->Location().device);

0 commit comments

Comments (0)