Skip to content

Commit 5fa8bd0

Browse files
authored
[WebNN] Add op support validation for decomposed WebNN ops (microsoft#23370)
- Some ONNX op are supported by decomposed WebNN ops, defines a `decomposed_op_map` map to specific decomposed WebNN ops list. - WebNN ops have various first input names such as 'a', 'input', 'inputs', etc. Defines a `webnn_op_first_input_name_map` map to record special names other than 'input', and a `GetWebNNOpFirstInputName` function to retrieve the first input name of a WebNN op. - Check if the input and output data types are supported by each decomposed WebNN op. - Remove the unnecessary `CheckSingleOp` function, because WebNN's `OpSupportLimits` has already covered op supported check.
1 parent 8c3e34d commit 5fa8bd0

25 files changed

+261
-112
lines changed

onnxruntime/core/providers/webnn/builders/helper.cc

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -107,11 +107,7 @@ std::unordered_set<const Node*> GetSupportedNodes(const GraphViewer& graph_viewe
107107
std::unordered_set<const Node*> supported_nodes;
108108

109109
for (const auto& node : graph_viewer.Nodes()) {
110-
bool supported = false;
111-
// Firstly check if platform supports the WebNN op.
112-
if (CheckSingleOp(node.OpType(), wnn_builder, device_type)) {
113-
supported = IsNodeSupported(node, graph_viewer, device_type, wnn_limits, logger);
114-
}
110+
const bool supported = IsNodeSupported(node, graph_viewer, device_type, wnn_limits, logger);
115111
LOGS(logger, VERBOSE) << "Operator type: [" << node.OpType()
116112
<< "] index: [" << node.Index()
117113
<< "] name: [" << node.Name()
@@ -125,7 +121,7 @@ std::unordered_set<const Node*> GetSupportedNodes(const GraphViewer& graph_viewe
125121
return supported_nodes;
126122
}
127123

128-
bool AreInputDataTypesSame(const std::string& op_type,
124+
bool AreInputDataTypesSame(const std::string_view op_type,
129125
gsl::span<const int32_t> input_types,
130126
const logging::Logger& logger) {
131127
for (size_t i = 1; i < input_types.size(); i++) {
@@ -145,46 +141,47 @@ bool IsSupportedDataType(const int32_t onnx_data_type, const emscripten::val& we
145141
if (it == onnx_to_webnn_data_type_map.end())
146142
return false;
147143

148-
std::string webnn_data_type = it->second;
144+
const std::string_view webnn_data_type = it->second;
149145

150146
// Check if WebNN supports the data type.
151-
emscripten::val is_supported = webnn_supported_data_types.call<emscripten::val>("includes",
152-
emscripten::val(webnn_data_type));
147+
emscripten::val is_supported =
148+
webnn_supported_data_types.call<emscripten::val>("includes", emscripten::val(std::string(webnn_data_type)));
153149
return is_supported.as<bool>();
154150
}
155151

156152
// Check if the input or output data type of ONNX node is supported by the WebNN operator.
157-
bool IsDataTypeSupportedByOp(const std::string& onnx_op_type,
153+
bool IsDataTypeSupportedByOp(const std::string_view onnx_op_type,
158154
const int32_t onnx_data_type,
159155
const emscripten::val& wnn_limits,
160-
const std::string& webnn_input_output_name,
161-
const std::string& onnx_input_output_name,
156+
const std::string_view webnn_input_output_name,
157+
const std::string_view onnx_input_output_name,
162158
const logging::Logger& logger) {
163-
std::string webnn_op_type;
164-
if (!GetWebNNOpType(onnx_op_type, webnn_op_type))
165-
return false;
159+
const std::string_view webnn_op_type = GetWebNNOpType(onnx_op_type);
166160

167-
return IsDataTypeSupportedByWebNNOp(onnx_op_type, webnn_op_type, onnx_data_type, wnn_limits,
161+
return !webnn_op_type.empty() &&
162+
IsDataTypeSupportedByWebNNOp(onnx_op_type, webnn_op_type, onnx_data_type, wnn_limits,
168163
webnn_input_output_name, onnx_input_output_name, logger);
169164
}
170165

171-
bool IsDataTypeSupportedByWebNNOp(const std::string& onnx_op_type,
172-
const std::string& webnn_op_type,
166+
bool IsDataTypeSupportedByWebNNOp(const std::string_view onnx_op_type,
167+
const std::string_view webnn_op_type,
173168
const int32_t onnx_data_type,
174169
const emscripten::val& wnn_limits,
175-
const std::string& webnn_input_output_name,
176-
const std::string& onnx_input_output_name,
170+
const std::string_view webnn_input_output_name,
171+
const std::string_view onnx_input_output_name,
177172
const logging::Logger& logger) {
178-
if (wnn_limits[webnn_op_type].isUndefined()) {
173+
if (wnn_limits[std::string(webnn_op_type)].isUndefined()) {
179174
LOGS(logger, VERBOSE) << "[" << onnx_op_type << "] WebNN op [" << webnn_op_type << "] is not supported for now";
180175
return false;
181176
}
182-
if (wnn_limits[webnn_op_type][webnn_input_output_name].isUndefined()) {
177+
178+
if (wnn_limits[std::string(webnn_op_type)][std::string(webnn_input_output_name)].isUndefined()) {
183179
LOGS(logger, VERBOSE) << "[" << onnx_op_type << "] WebNN op [" << webnn_op_type << "] doesn't have parameter ["
184180
<< webnn_input_output_name << "]";
185181
return false;
186182
}
187-
if (!IsSupportedDataType(onnx_data_type, wnn_limits[webnn_op_type][webnn_input_output_name]["dataTypes"])) {
183+
if (!IsSupportedDataType(
184+
onnx_data_type, wnn_limits[std::string(webnn_op_type)][std::string(webnn_input_output_name)]["dataTypes"])) {
188185
LOGS(logger, VERBOSE) << "[" << onnx_op_type << "] " << onnx_input_output_name << "'s data type: ["
189186
<< onnx_data_type << "] is not supported by WebNN op [" << webnn_op_type << "] for now";
190187
return false;

onnxruntime/core/providers/webnn/builders/helper.h

Lines changed: 52 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -194,9 +194,16 @@ std::unordered_set<const Node*> GetSupportedNodes(const GraphViewer& graph_viewe
194194
const WebnnDeviceType device_type,
195195
const emscripten::val& wnn_limits,
196196
const logging::Logger& logger);
197-
// TODO(@Honry): Some ONNX ops are supported by decomposed WebNN ops,
198-
// we need to check the support of the decomposed ops.
199-
static const InlinedHashMap<std::string, std::string> op_map = {
197+
198+
// Some ONNX ops are supported by decomposed WebNN ops.
199+
const std::map<std::string_view, std::vector<std::string_view>> decomposed_op_map = {
200+
{"LRN", {"add", "averagePool2d", "div", "mul", "pad", "pow", "transpose"}},
201+
{"RotaryEmbedding", {"add", "concat", "gather", "mul", "reshape", "split"}},
202+
{"SimplifiedLayerNormalization", {"add", "div", "mul", "pow", "reduceMean", "sqrt"}},
203+
{"SkipSimplifiedLayerNormalization", {"add", "div", "mul", "pow", "reduceMean", "sqrt"}},
204+
};
205+
// ONNX op type to WebNN op type mapping.
206+
const std::map<std::string_view, std::string_view> op_map = {
200207
{"Abs", "abs"},
201208
{"Add", "add"},
202209
{"And", "logicalAnd"},
@@ -247,7 +254,6 @@ static const InlinedHashMap<std::string, std::string> op_map = {
247254
{"Log", "log"},
248255
{"LpPool", "l2Pool2d"},
249256
{"LSTM", "lstm"},
250-
{"LRN", "averagePool2d"},
251257
{"MatMul", "matmul"},
252258
{"MatMulInteger", "matmulInteger"},
253259
{"Max", "max"},
@@ -275,17 +281,14 @@ static const InlinedHashMap<std::string, std::string> op_map = {
275281
{"Relu", "relu"},
276282
{"Reshape", "reshape"},
277283
{"Resize", "resample2d"},
278-
{"RotaryEmbedding", "gather"},
279284
{"ScatterElements", "scatterElements"},
280285
{"ScatterND", "scatterND"},
281286
{"Shape", "slice"},
282287
{"Sigmoid", "sigmoid"},
283288
{"Sign", "sign"},
284-
{"SimplifiedLayerNormalization", "layerNormalization"},
285289
{"Softplus", "softplus"},
286290
{"Softsign", "softsign"},
287291
{"Sin", "sin"},
288-
{"SkipSimplifiedLayerNormalization", "layerNormalization"},
289292
{"Slice", "slice"},
290293
{"Softmax", "softmax"},
291294
{"Split", "split"},
@@ -302,29 +305,46 @@ static const InlinedHashMap<std::string, std::string> op_map = {
302305
{"Xor", "logicalXor"},
303306
};
304307

305-
inline bool CheckSingleOp(const std::string& op_type, const emscripten::val& wnn_builder,
306-
const WebnnDeviceType device_type) {
307-
auto op_map_entry = op_map.find(op_type);
308-
// Returns false if the op_type is not listed in the op_map or
309-
// if the WebNN op has not been implemented in MLGraphBuilder in current browser.
310-
if (op_map_entry == op_map.end() || !wnn_builder[op_map_entry->second].as<bool>()) {
311-
return false;
312-
}
308+
// WebNN op name to its first input name mapping, only record the name that is different from "input".
309+
// This map is used to determine the first input name of a WebNN op and is utilized by OpSupportLimits.
310+
const std::map<std::string_view, std::string_view> webnn_op_first_input_name_map = {
311+
{"add", "a"},
312+
{"concat", "inputs"},
313+
{"div", "a"},
314+
{"equal", "a"},
315+
{"gemm", "a"},
316+
{"greater", "a"},
317+
{"greaterOrEqual", "a"},
318+
{"lesser", "a"},
319+
{"lesserOrEqual", "a"},
320+
{"logicalAnd", "a"},
321+
{"logicalNot", "a"},
322+
{"logicalOr", "a"},
323+
{"logicalXor", "a"},
324+
{"matmul", "a"},
325+
{"max", "a"},
326+
{"min", "a"},
327+
{"mul", "a"},
328+
{"pow", "a"},
329+
{"sub", "a"},
330+
{"where", "condition"},
331+
};
313332

314-
return true;
333+
// Retrieve the first input name of a WebNN op used for validating supported input data types.
334+
// WebNN ops have various first input names such as 'a', 'input', 'inputs', etc.
335+
// Special names other than 'input' are recorded in the webnn_op_first_input_name_map.
336+
inline std::string_view GetWebNNOpFirstInputName(const std::string_view webnn_op_type) {
337+
auto it = webnn_op_first_input_name_map.find(webnn_op_type);
338+
return (it != webnn_op_first_input_name_map.end()) ? it->second : "input";
315339
}
316340

317-
inline bool GetWebNNOpType(const std::string& op_type, std::string& webnn_op_type) {
341+
inline std::string_view GetWebNNOpType(const std::string_view op_type) {
318342
auto it = op_map.find(op_type);
319-
// Returns false if the op_type is not listed in the op_map.
320-
if (it == op_map.end()) {
321-
return false;
322-
}
323-
webnn_op_type = it->second;
324-
return true;
343+
// Return an empty string if the op_type is not listed in the op_map.
344+
return (it != op_map.end()) ? it->second : "";
325345
}
326346

327-
static const InlinedHashMap<ONNX_NAMESPACE::TensorProto_DataType, std::string> onnx_to_webnn_data_type_map = {
347+
const std::map<ONNX_NAMESPACE::TensorProto_DataType, std::string_view> onnx_to_webnn_data_type_map = {
328348
{ONNX_NAMESPACE::TensorProto_DataType_INT4, "int4"},
329349
{ONNX_NAMESPACE::TensorProto_DataType_UINT4, "uint4"},
330350
{ONNX_NAMESPACE::TensorProto_DataType_BOOL, "uint8"},
@@ -338,22 +358,22 @@ static const InlinedHashMap<ONNX_NAMESPACE::TensorProto_DataType, std::string> o
338358
{ONNX_NAMESPACE::TensorProto_DataType_UINT64, "uint64"},
339359
};
340360

341-
bool AreInputDataTypesSame(const std::string& op_type,
361+
bool AreInputDataTypesSame(const std::string_view op_type,
342362
gsl::span<const int32_t> input_types,
343363
const logging::Logger& logger);
344364
bool IsSupportedDataType(const int32_t onnx_data_type, const emscripten::val& webnn_supported_data_types);
345-
bool IsDataTypeSupportedByOp(const std::string& onnx_op_type,
365+
bool IsDataTypeSupportedByOp(const std::string_view onnx_op_type,
346366
const int32_t onnx_data_type,
347367
const emscripten::val& wnn_limits,
348-
const std::string& webnn_input_output_name,
349-
const std::string& onnx_input_output_name,
368+
const std::string_view webnn_input_output_name,
369+
const std::string_view onnx_input_output_name,
350370
const logging::Logger& logger);
351-
bool IsDataTypeSupportedByWebNNOp(const std::string& onnx_op_type,
352-
const std::string& webnn_op_type,
371+
bool IsDataTypeSupportedByWebNNOp(const std::string_view onnx_op_type,
372+
const std::string_view webnn_op_type,
353373
const int32_t onnx_data_type,
354374
const emscripten::val& wnn_limits,
355-
const std::string& webnn_input_output_name,
356-
const std::string& onnx_input_output_name,
375+
const std::string_view webnn_input_output_name,
376+
const std::string_view onnx_input_output_name,
357377
const logging::Logger& logger);
358378

359379
bool GetBidirectionalBroadcastShape(std::vector<int64_t>& shape_a,

onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,12 +58,17 @@ bool BaseOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& initializ
5858
const logging::Logger& logger) const {
5959
// We only check the type of input 0 by default, specific op builder can override this.
6060
const auto& input = *node.InputDefs()[0];
61-
const auto& op_type = node.OpType();
61+
const std::string_view op_type = node.OpType();
6262
int32_t input_type;
6363
if (!GetType(input, input_type, logger))
6464
return false;
65+
const std::string_view webnn_op_type = GetWebNNOpType(op_type);
66+
if (webnn_op_type.empty())
67+
return false;
6568

66-
return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "Input", logger);
69+
const std::string_view webnn_input_name = GetWebNNOpFirstInputName(webnn_op_type);
70+
return IsDataTypeSupportedByWebNNOp(op_type, webnn_op_type, input_type, wnn_limits,
71+
webnn_input_name, "input", logger);
6772
}
6873

6974
bool BaseOpBuilder::HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits,
@@ -83,7 +88,7 @@ bool BaseOpBuilder::HasSupportedOutputsImpl(const Node& node,
8388
const logging::Logger& logger) const {
8489
// We only check the type of output 0 by default, specific op builder can override this.
8590
const auto& output = *node.OutputDefs()[0];
86-
const auto& op_type = node.OpType();
91+
const std::string_view op_type = node.OpType();
8792
int32_t output_type;
8893
if (!GetType(output, output_type, logger))
8994
return false;

onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ Status BinaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const
6060
bool BinaryOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
6161
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
6262
const auto& input_defs = node.InputDefs();
63-
const auto& op_type = node.OpType();
63+
const std::string_view op_type = node.OpType();
6464
int32_t input0_type;
6565
int32_t input1_type;
6666

onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,6 @@ class CastOpBuilder : public BaseOpBuilder {
1818
private:
1919
Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
2020
const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
21-
22-
// Operator support related.
23-
private:
24-
bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
25-
const emscripten::val& wnn_limits, const logging::Logger& logger) const override;
2621
};
2722

2823
// Add operator related.
@@ -85,25 +80,6 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
8580
return Status::OK();
8681
}
8782

88-
// Operator support related.
89-
bool CastOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
90-
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
91-
const auto& input_defs = node.InputDefs();
92-
const auto& op_type = node.OpType();
93-
int32_t input_type;
94-
95-
if (!GetType(*input_defs[0], input_type, logger))
96-
return false;
97-
98-
if (!IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "input", logger))
99-
return false;
100-
101-
NodeAttrHelper helper(node);
102-
// Check cast to type.
103-
const auto to_type = helper.Get("to", ONNX_NAMESPACE::TensorProto_DataType_UNDEFINED);
104-
return IsDataTypeSupportedByOp(op_type, to_type, wnn_limits, "output", "to", logger);
105-
}
106-
10783
void CreateCastOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
10884
op_registrations.builders.push_back(std::make_unique<CastOpBuilder>());
10985
op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());

onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder,
5858
bool ConcatOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
5959
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
6060
const auto& input_defs = node.InputDefs();
61-
const auto& op_type = node.OpType();
61+
const std::string_view op_type = node.OpType();
6262
int32_t input0_type;
6363

6464
if (!GetType(*input_defs[0], input0_type, logger))

onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers,
384384
bool ConvOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node,
385385
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
386386
const auto& input_defs = node.InputDefs();
387-
const auto& op_type = node.OpType();
387+
const std::string_view op_type = node.OpType();
388388
int32_t input0_type; // input data type
389389
int32_t input1_type; // weight data type
390390
int32_t input2_type; // bias or x_zero_point data type

onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -739,7 +739,7 @@ bool EinsumOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* init
739739
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
740740
const auto& input_defs = node.InputDefs();
741741

742-
const auto& op_type = node.OpType();
742+
const std::string_view op_type = node.OpType();
743743
int32_t input0_type;
744744
int32_t input1_type;
745745
bool has_input1 = TensorExists(input_defs, 1);

onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ bool GatherElementsOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet&
5454
const logging::Logger& logger) const {
5555
const auto& data = *node.InputDefs()[0];
5656
const auto& indices = *node.InputDefs()[1];
57-
const auto& op_type = node.OpType();
57+
const std::string_view op_type = node.OpType();
5858

5959
int32_t data_type;
6060
int32_t indices_type;

onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ bool GatherNDOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* in
5959
const emscripten::val& wnn_limits, const logging::Logger& logger) const {
6060
const auto& data = *node.InputDefs()[0];
6161
const auto& indices = *node.InputDefs()[1];
62-
const auto& op_type = node.OpType();
62+
const std::string_view op_type = node.OpType();
6363

6464
int32_t data_type;
6565
int32_t indices_type;

0 commit comments

Comments
 (0)