Skip to content

Commit 764c5b1

Browse files
HonrynaomiOvad
authored andcommitted
[WebNN] Remove constraints for Gemm's C input (microsoft#26273)
Now WebNN implementation for gemm's C operand has supported unidirectional broadcasting, which is align with ONNX spec. Removing constraints for Gemm's C input as which should be covered in ORT kernel.
1 parent b7d1051 commit 764c5b1

File tree

2 files changed

+1
-24
lines changed

2 files changed

+1
-24
lines changed

js/web/docs/webnn-operators.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ platforms. Check the [WebNN status](https://webmachinelearning.github.io/webnn-s
4646
| GatherElements | ai.onnx(11-12, 13+) | gatherElements | |
4747
| GatherND | ai.onnx(11, 12, 13+) | gatherND | Only supports 'batch_dims' == 0 |
4848
| Gelu | ai.onnx(20+) | gelu | |
49-
| Gemm | ai.onnx(7-8, 9-10, 11-12, 13+) | gemm | Only supports 1-D 'C' input |
49+
| Gemm | ai.onnx(7-8, 9-10, 11-12, 13+) | gemm | |
5050
| GlobalAveragePool | ai.onnx(7+) | averagePool2d | Only supports 4-D input |
5151
| GlobalMaxPool | ai.onnx(7+) | maxPool2d | Only supports 4-D input |
5252
| GlobalLpPool| ai.onnx(7+) | l2Pool2d | Only supports 4-D input, 'p' value is 2 |

onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -250,29 +250,6 @@ bool GemmOpBuilder::IsOpSupportedImpl(const GraphViewer&,
250250
std::vector<int64_t> c_shape;
251251
if (!GetShape(*input_defs[c_idx], c_shape, logger))
252252
return false;
253-
254-
size_t c_dim = c_shape.size();
255-
256-
if (c_dim > 1) {
257-
// TODO: Supports other shape of C.
258-
// Currently WebNN implementation in Chromium only supports 1-D C.
259-
return false;
260-
}
261-
if (c_dim == 0) {
262-
LOGS(logger, VERBOSE) << "C of Gemm is a scalar";
263-
} else {
264-
auto c_size = c_shape[c_dim - 1];
265-
NodeAttrHelper helper(node);
266-
const auto transB = helper.Get("transB", 0);
267-
if (c_size != (transB == 0 ? b_shape[1] : b_shape[0])) {
268-
LOGS(logger, VERBOSE) << "C of Gemm must be a vector of b_shape["
269-
<< (transB == 0 ? "1" : "0") << "]"
270-
<< " b_shape: [" << b_shape[0] << ", " << b_shape[1] << "]"
271-
<< " c_size: " << c_size;
272-
273-
return false;
274-
}
275-
}
276253
}
277254
}
278255

0 commit comments

Comments
 (0)