Skip to content

Commit dbfbebe

Browse files
authored
[WebNN] Handle in-memory external data (microsoft#25079)
### Description Some initializers are stored as in-memory external data, WebNN EP should support these initializers. ### Motivation and Context This PR: - Added `HasExternalDataInMemory` check for external data to avoid unexpected error. - Wrapped the `UnpackInitializerData` to make it compatible with external data. Fixed microsoft#25078
1 parent 74126d1 commit dbfbebe

File tree

13 files changed

+58
-40
lines changed

13 files changed

+58
-40
lines changed

onnxruntime/core/providers/webnn/builders/helper.h

Lines changed: 28 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include <emscripten.h>
1616
#include <emscripten/val.h>
1717

18+
using onnxruntime::common::Status;
1819
namespace onnxruntime {
1920

2021
class GraphViewer;
@@ -92,14 +93,33 @@ inline std::vector<T> GetNarrowedIntfromInt64(gsl::span<const int64_t> int64_vec
9293
return vec;
9394
}
9495

96+
bool inline UnpackInitializerData(const ONNX_NAMESPACE::TensorProto& initializer,
97+
std::vector<uint8_t>& unpacked_tensor,
98+
const GraphViewer& graph_viewer,
99+
const logging::Logger& logger) {
100+
Status status = Status::OK();
101+
if (utils::HasExternalData(initializer)) {
102+
status = onnxruntime::utils::UnpackInitializerData(initializer, graph_viewer.ModelPath(), unpacked_tensor);
103+
} else {
104+
status = onnxruntime::utils::UnpackInitializerData(initializer, unpacked_tensor);
105+
}
106+
107+
if (!status.IsOK()) {
108+
LOGS(logger, ERROR) << "Error while unpacking initializer data: " << status.ErrorMessage();
109+
return false;
110+
}
111+
112+
return true;
113+
}
114+
95115
template <typename T>
96-
bool ReadIntArrayFrom1DTensor(const onnx::TensorProto& tensor, std::vector<T>& array, const logging::Logger& logger) {
116+
bool ReadIntArrayFrom1DTensor(const onnx::TensorProto& tensor, std::vector<T>& array,
117+
const GraphViewer& graph_viewer, const logging::Logger& logger) {
97118
std::vector<uint8_t> unpacked_tensor;
98-
auto status = onnxruntime::utils::UnpackInitializerData(tensor, unpacked_tensor);
99-
if (!status.IsOK()) {
100-
LOGS(logger, ERROR) << "Error while unpacking shape: " << status.ErrorMessage();
119+
if (!UnpackInitializerData(tensor, unpacked_tensor, graph_viewer, logger)) {
101120
return false;
102121
}
122+
103123
const auto& dims = tensor.dims();
104124
if (dims.size() != 1) {
105125
LOGS(logger, VERBOSE) << "The tensor must be 1D.";
@@ -130,13 +150,13 @@ bool ReadIntArrayFrom1DTensor(const onnx::TensorProto& tensor, std::vector<T>& a
130150
return true;
131151
}
132152

133-
inline bool ReadScalarTensorData(const onnx::TensorProto& tensor, emscripten::val& scalar, const logging::Logger& logger) {
153+
inline bool ReadScalarTensorData(const onnx::TensorProto& tensor, emscripten::val& scalar,
154+
const GraphViewer& graph_viewer, const logging::Logger& logger) {
134155
std::vector<uint8_t> unpacked_tensor;
135-
auto status = onnxruntime::utils::UnpackInitializerData(tensor, unpacked_tensor);
136-
if (!status.IsOK()) {
137-
LOGS(logger, ERROR) << "Error while unpacking tensor: " << status.ErrorMessage();
156+
if (!UnpackInitializerData(tensor, unpacked_tensor, graph_viewer, logger)) {
138157
return false;
139158
}
159+
140160
switch (tensor.data_type()) {
141161
case ONNX_NAMESPACE::TensorProto_DataType_BOOL:
142162
case ONNX_NAMESPACE::TensorProto_DataType_UINT8:

onnxruntime/core/providers/webnn/builders/impl/cumsum_op_builder.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ Status CumSumOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
5050
const std::string axis_name = GetTensorName(input_defs, 1);
5151
const auto axis_tensor = *initializers.at(axis_name);
5252
emscripten::val axis = emscripten::val::undefined();
53-
ORT_RETURN_IF_NOT(ReadScalarTensorData(axis_tensor, axis, logger), "Cannot get axis value");
53+
ORT_RETURN_IF_NOT(ReadScalarTensorData(axis_tensor, axis, model_builder.GetGraphViewer(), logger),
54+
"Cannot get axis value");
5455
int64_t webnn_axis = HandleNegativeAxis(axis.as<int64_t>(), input_rank);
5556

5657
NodeAttrHelper helper(node);

onnxruntime/core/providers/webnn/builders/impl/expand_op_builder.cc

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,8 @@ Status ExpandOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
4444
const auto& initializers(model_builder.GetInitializerTensors());
4545
const auto& shape_tensor = *initializers.at(input_defs[1]->Name());
4646
std::vector<int64_t> new_shape;
47-
ORT_RETURN_IF_NOT(ReadIntArrayFrom1DTensor(shape_tensor, new_shape, logger), "Cannot get shape.");
47+
ORT_RETURN_IF_NOT(ReadIntArrayFrom1DTensor(shape_tensor, new_shape, model_builder.GetGraphViewer(), logger),
48+
"Cannot get shape.");
4849
emscripten::val input = model_builder.GetOperand(input_defs[0]->Name());
4950
std::vector<int64_t> input_shape;
5051
ORT_RETURN_IF_NOT(GetShape(*input_defs[0], input_shape, logger), "Cannot get input's shape.");
@@ -84,8 +85,7 @@ bool ExpandOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer,
8485
}
8586

8687
std::vector<int64_t> new_shape;
87-
if (!ReadIntArrayFrom1DTensor(shape_tensor, new_shape, logger)) {
88-
LOGS(logger, VERBOSE) << "Cannot get shape.";
88+
if (!ReadIntArrayFrom1DTensor(shape_tensor, new_shape, graph_viewer, logger)) {
8989
return false;
9090
}
9191
if (std::any_of(new_shape.begin(), new_shape.end(), [](int64_t dimension) { return dimension == 0; })) {

onnxruntime/core/providers/webnn/builders/impl/gqa_op_builder.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -392,7 +392,7 @@ bool GroupQueryAttentionOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_vi
392392

393393
const auto total_sequence_length_tensor = *total_sequence_length_initializer;
394394
emscripten::val total_sequence_length = emscripten::val::undefined();
395-
if (!ReadScalarTensorData(total_sequence_length_tensor, total_sequence_length, logger)) {
395+
if (!ReadScalarTensorData(total_sequence_length_tensor, total_sequence_length, graph_viewer, logger)) {
396396
return false;
397397
}
398398

onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,8 +143,7 @@ bool GruOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const Node
143143

144144
const auto& sequence_lens_tensor = *seq_initializer;
145145
std::vector<int32_t> sequence_lens;
146-
if (!ReadIntArrayFrom1DTensor(sequence_lens_tensor, sequence_lens, logger)) {
147-
LOGS(logger, ERROR) << "Cannot read sequence lens tensor";
146+
if (!ReadIntArrayFrom1DTensor(sequence_lens_tensor, sequence_lens, graph_viewer, logger)) {
148147
return false;
149148
}
150149
if (!std::all_of(sequence_lens.begin(), sequence_lens.end(),

onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,8 +149,7 @@ bool LstmOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer, const Nod
149149

150150
const auto& sequence_lens_tensor = *sequence_lens_init;
151151
std::vector<int32_t> sequence_lens;
152-
if (!ReadIntArrayFrom1DTensor(sequence_lens_tensor, sequence_lens, logger)) {
153-
LOGS(logger, ERROR) << "Cannot read sequence lens tensor";
152+
if (!ReadIntArrayFrom1DTensor(sequence_lens_tensor, sequence_lens, graph_viewer, logger)) {
154153
return false;
155154
}
156155
if (std::any_of(sequence_lens.begin(), sequence_lens.end(),

onnxruntime/core/providers/webnn/builders/impl/pad_op_builder.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,24 +83,27 @@ Status PadOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
8383
const auto opset = node.SinceVersion();
8484
// From opset 11, pads, constant value and axes are inputs.
8585
if (opset >= 11) {
86+
const auto& graph_viewer = model_builder.GetGraphViewer();
8687
ORT_RETURN_IF(input_defs.size() < 2, "Pads is required at opset ", opset);
8788
std::vector<int64_t> pads;
8889
const auto& pads_tensor = *initializers.at(input_defs[1]->Name());
89-
ORT_RETURN_IF_NOT(ReadIntArrayFrom1DTensor(pads_tensor, pads, logger), "Error while read pads tensor");
90+
ORT_RETURN_IF_NOT(ReadIntArrayFrom1DTensor(pads_tensor, pads, graph_viewer, logger),
91+
"Error while reading pads tensor");
9092

9193
// Constant value and axes are optional. Make sure they are not empty.
9294
if (!GetTensorName(input_defs, 2).empty()) {
9395
const auto value_tensor = *initializers.at(input_defs[2]->Name());
9496
emscripten::val value = emscripten::val::object();
95-
ORT_RETURN_IF_NOT(ReadScalarTensorData(value_tensor, value, logger), "Cannot read constant value");
97+
ORT_RETURN_IF_NOT(ReadScalarTensorData(value_tensor, value, graph_viewer, logger), "Cannot read constant value");
9698
options.set("value", value);
9799
}
98100

99101
if (!GetTensorName(input_defs, 3).empty()) {
100102
const auto input_rank = input_shape.size();
101103
std::vector<int64_t> axes;
102104
const auto& axes_tensor = *initializers.at(input_defs[3]->Name());
103-
ORT_RETURN_IF_NOT(ReadIntArrayFrom1DTensor(axes_tensor, axes, logger), "Error while read axes tensor");
105+
ORT_RETURN_IF_NOT(ReadIntArrayFrom1DTensor(axes_tensor, axes, graph_viewer, logger),
106+
"Error while reading axes tensor");
104107
std::vector<size_t> axes_index;
105108
std::transform(
106109
axes.begin(), axes.end(), std::back_inserter(axes_index),

onnxruntime/core/providers/webnn/builders/impl/reshape_op_builder.cc

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,9 +93,7 @@ bool ReshapeOpBuilder::IsOpSupportedImpl(const GraphViewer& graph_viewer,
9393

9494
const auto& perm_tensor = *perm_init;
9595
std::vector<uint8_t> unpacked_tensor;
96-
auto status = onnxruntime::utils::UnpackInitializerData(perm_tensor, unpacked_tensor);
97-
if (!status.IsOK()) {
98-
LOGS(logger, ERROR) << "Error while unpacking perm_tensor: " << status.ErrorMessage();
96+
if (!UnpackInitializerData(perm_tensor, unpacked_tensor, graph_viewer, logger)) {
9997
return false;
10098
}
10199

onnxruntime/core/providers/webnn/builders/impl/resize_op_builder.cc

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,9 @@ bool GetResizeScalesAndAxes(const GraphViewer& graph_viewer,
6969
}
7070

7171
std::vector<uint8_t> unpacked_tensor;
72-
auto status = onnxruntime::utils::UnpackInitializerData(scales_tensor, unpacked_tensor);
73-
if (!status.IsOK()) {
74-
LOGS(logger, ERROR) << "Error while unpacking scales_tensor: " << status.ErrorMessage();
72+
if (!UnpackInitializerData(scales_tensor, unpacked_tensor, graph_viewer, logger)) {
7573
return false;
76-
}
74+
};
7775
const float* scales_data = reinterpret_cast<const float*>(unpacked_tensor.data());
7876

7977
if (has_axes) {
@@ -137,9 +135,7 @@ bool GetResizeSizesAndAxes(const GraphViewer& graph_viewer,
137135
}
138136

139137
std::vector<uint8_t> unpacked_tensor;
140-
auto status = onnxruntime::utils::UnpackInitializerData(sizes_tensor, unpacked_tensor);
141-
if (!status.IsOK()) {
142-
LOGS(logger, ERROR) << "Error while unpacking sizes_tensor: " << status.ErrorMessage();
138+
if (!UnpackInitializerData(sizes_tensor, unpacked_tensor, graph_viewer, logger)) {
143139
return false;
144140
}
145141
const int64_t* sizes_data = reinterpret_cast<const int64_t*>(unpacked_tensor.data());

onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,8 @@ Status SliceOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
7272
input_name = input_defs[input_idx]->Name();
7373
const auto& initializers(model_builder.GetInitializerTensors());
7474
const auto& tensor = *initializers.at(input_name);
75-
if (!ReadIntArrayFrom1DTensor(tensor, data, logger)) {
76-
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Data type for starts and ends inputs is not supported in this build.");
77-
}
75+
ORT_RETURN_IF_NOT(ReadIntArrayFrom1DTensor(tensor, data, model_builder.GetGraphViewer(), logger),
76+
"Data type for starts or ends inputs is not supported in this build.");
7877

7978
return Status::OK();
8079
};
@@ -176,7 +175,7 @@ bool SliceOpBuilder::HasSupportedInputsImpl(const GraphViewer& graph_viewer, con
176175
if (TensorExists(input_defs, 4)) {
177176
std::vector<int64_t> steps;
178177
const auto* init = graph_viewer.GetConstantInitializer(input_defs[4]->Name());
179-
if (!init || !ReadIntArrayFrom1DTensor(*init, steps, logger))
178+
if (!init || !ReadIntArrayFrom1DTensor(*init, steps, graph_viewer, logger))
180179
return false;
181180
if (std::any_of(steps.begin(), steps.end(), [](int64_t step) { return step < 0; })) {
182181
if (!IsDataTypeSupportedByWebNNOp(op_type, "reverse", input_type, wnn_limits, "input", "data", logger)) {

0 commit comments

Comments
 (0)