
Commit 932e175

Honrygedoensmax authored and committed
[WebNN] Remove NHWC preferred layout (microsoft#25679)
Chromium has implemented constant folding for transpose, which eliminates the performance penalty introduced by the extra transpose ops inserted for contexts with the NHWC preferred layout. Chromium CL: https://chromium-review.googlesource.com/c/chromium/src/+/6774969. We can now safely remove the NHWC optimization from the WebNN EP.
1 parent 1865637 commit 932e175
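
For context (an illustrative sketch, not code from this commit): with an NHWC-preferred context, the layout transformer wraps each NCHW op in a transpose pair, NCHW -> NHWC on the way in and NHWC -> NCHW on the way out. Those two permutations compose to the identity, which is exactly what lets Chromium's constant-folding pass erase the back-to-back pair. The names below are illustrative:

// Illustrative sketch (not from this commit): the two permutations a
// layout-transformed graph inserts around an NCHW op, and the identity
// that lets Chromium's constant folding erase the back-to-back pair.
#include <array>
#include <cstdint>

constexpr std::array<uint32_t, 4> kNchwToNhwc = {0, 2, 3, 1};  // N,C,H,W -> N,H,W,C
constexpr std::array<uint32_t, 4> kNhwcToNchw = {0, 3, 1, 2};  // N,H,W,C -> N,C,H,W

// Applying one permutation and then the other maps every axis back to
// itself, so transpose(transpose(x, A), B) folds away entirely.
constexpr bool ComposesToIdentity() {
  for (uint32_t i = 0; i < 4; ++i) {
    if (kNhwcToNchw[kNchwToNhwc[i]] != i) return false;
  }
  return true;
}
static_assert(ComposesToIdentity(), "the transpose pair cancels out");

Once the folding pass cancels such pairs, the extra ops cost nothing, so the EP-side NHWC machinery below becomes dead weight.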

12 files changed: +69 −366 lines

onnxruntime/core/providers/webnn/builders/helper.h

Lines changed: 0 additions & 13 deletions
@@ -38,19 +38,6 @@ WebnnDeviceType DeviceTypeFromString(const std::string_view& device_type);
 // Collects all the initializer tensors in the subGraph and its ancestor graphs.
 InitializedTensorSet CollectAllInitializedTensors(const GraphViewer& graph_viewer);
 
-inline std::vector<int64_t> convertAxesFromNCHWtoNHWC(const std::vector<int64_t>& axes) {
-  constexpr std::array<int64_t, 4> nchw_to_nhwc = {0, 3, 1, 2};
-  std::vector<int64_t> new_axes;
-  new_axes.reserve(axes.size());
-  for (int64_t axis : axes) {
-    if (axis >= nchw_to_nhwc.size()) {
-      ORT_THROW("Invalid axis value: ", axis);
-    }
-    new_axes.push_back(nchw_to_nhwc[static_cast<size_t>(axis)]);
-  }
-  return new_axes;
-}
-
 inline std::vector<int64_t> HandleNegativeAxes(const std::vector<int64_t>& axes, size_t input_size) {
   std::vector<int64_t> new_axes(axes.size());
   for (size_t i = 0; i < axes.size(); ++i) {
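
For reference, a standalone sketch of what the deleted helper computed: it mapped axis indices given in NCHW order to their positions after a transpose to NHWC (N stays at 0, C moves to 3, H to 1, W to 2). The bounds handling and main() here are illustrative; the original used ORT_THROW and left negative axes to HandleNegativeAxes:

// Standalone sketch of the deleted convertAxesFromNCHWtoNHWC helper.
#include <array>
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <vector>

std::vector<int64_t> ConvertAxesFromNchwToNhwc(const std::vector<int64_t>& axes) {
  // Where each NCHW axis lands after transposing to NHWC:
  // N stays at 0, C moves to 3, H moves to 1, W moves to 2.
  constexpr std::array<int64_t, 4> nchw_to_nhwc = {0, 3, 1, 2};
  std::vector<int64_t> new_axes;
  new_axes.reserve(axes.size());
  for (int64_t axis : axes) {
    if (axis < 0 || axis >= static_cast<int64_t>(nchw_to_nhwc.size())) {
      throw std::out_of_range("Invalid axis value");
    }
    new_axes.push_back(nchw_to_nhwc[static_cast<size_t>(axis)]);
  }
  return new_axes;
}

int main() {
  // A reduction over the NCHW spatial axes {2, 3} becomes {1, 2} in NHWC.
  for (int64_t axis : ConvertAxesFromNchwToNhwc({2, 3})) std::cout << axis << ' ';
  std::cout << '\n';  // prints: 1 2
}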

onnxruntime/core/providers/webnn/builders/impl/builder_utils.cc

Lines changed: 12 additions & 19 deletions
@@ -18,10 +18,9 @@ common::Status ComputeConvPads(const std::vector<int64_t> input_shape,
                               const std::vector<int64_t>& onnx_strides,
                               const std::vector<int64_t>& onnx_dilations,
                               AutoPadType auto_pad_type,
-                              std::vector<int64_t>& pads_out,
-                              bool use_nchw) {
-  const int64_t input_size_y = use_nchw ? input_shape[2] : input_shape[1];
-  const int64_t input_size_x = use_nchw ? input_shape[3] : input_shape[2];
+                              std::vector<int64_t>& pads_out) {
+  const int64_t input_size_y = input_shape[2];
+  const int64_t input_size_x = input_shape[3];
   const int64_t stride_y = onnx_strides[0];
   const int64_t stride_x = onnx_strides[1];
   const int64_t dilation_y = onnx_dilations[0];
@@ -53,17 +52,12 @@ common::Status HandleAutoPad(const std::vector<int64_t> input_shape,
                             const std::vector<int64_t>& onnx_strides,
                             const std::vector<int64_t>& onnx_dilations,
                             AutoPadType auto_pad_type,
-                            std::vector<int64_t>& pads_out,
-                            bool use_nchw) {
-  if (AutoPadType::SAME_UPPER == auto_pad_type) {
-    ORT_RETURN_IF_ERROR(ComputeConvPads(input_shape, weight_size_y, weight_size_x,
-                                        onnx_pads, onnx_strides, onnx_dilations,
-                                        AutoPadType::SAME_UPPER, pads_out, use_nchw));
-  } else {
-    ORT_RETURN_IF_ERROR(ComputeConvPads(input_shape, weight_size_y, weight_size_x,
-                                        onnx_pads, onnx_strides, onnx_dilations,
-                                        AutoPadType::SAME_LOWER, pads_out, use_nchw));
-  }
+                            std::vector<int64_t>& pads_out) {
+  AutoPadType pad_type = (AutoPadType::SAME_UPPER == auto_pad_type) ? AutoPadType::SAME_UPPER : AutoPadType::SAME_LOWER;
+
+  ORT_RETURN_IF_ERROR(ComputeConvPads(input_shape, weight_size_y, weight_size_x,
+                                      onnx_pads, onnx_strides, onnx_dilations,
+                                      pad_type, pads_out));
   return Status::OK();
 }
 
@@ -110,10 +104,9 @@ common::Status ComputeConvTransposePadsAndOutputShape(const std::vector<int64_t>
                                                      const std::vector<int64_t>& onnx_output_padding,
                                                      AutoPadType auto_pad_type,
                                                      std::vector<int64_t>& pads_out,
-                                                     std::vector<int64_t>& output_shape_out,
-                                                     bool use_nchw) {
-  const int64_t input_size_y = use_nchw ? input_shape[2] : input_shape[1];
-  const int64_t input_size_x = use_nchw ? input_shape[3] : input_shape[2];
+                                                     std::vector<int64_t>& output_shape_out) {
+  const int64_t input_size_y = input_shape[2];
+  const int64_t input_size_x = input_shape[3];
   const int64_t stride_y = onnx_strides[0];
   const int64_t stride_x = onnx_strides[1];
   const int64_t dilation_y = onnx_dilations[0];
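
These hunks only drop the use_nchw flag; the pad arithmetic itself is unchanged and its body is not shown in this diff. For orientation, a sketch of the standard ONNX SAME_UPPER/SAME_LOWER pad computation for one spatial axis, which is the assumed behavior of ComputeConvPads (SamePads is an illustrative name, not a function in this file):

// Sketch of the standard ONNX SAME-pad arithmetic for one spatial axis
// (assumed behavior; the body of ComputeConvPads is not part of this diff).
#include <algorithm>
#include <cstdint>
#include <utility>

std::pair<int64_t, int64_t> SamePads(int64_t in_size, int64_t kernel,
                                     int64_t stride, int64_t dilation,
                                     bool same_upper) {
  const int64_t effective_kernel = (kernel - 1) * dilation + 1;
  const int64_t out_size = (in_size + stride - 1) / stride;  // ceil(in_size / stride)
  const int64_t total_pad =
      std::max<int64_t>(0, (out_size - 1) * stride + effective_kernel - in_size);
  // SAME_UPPER puts the odd extra pixel at the end, SAME_LOWER at the start.
  const int64_t head = same_upper ? total_pad / 2 : total_pad - total_pad / 2;
  return {head, total_pad - head};
}

// Example: in_size=224, kernel=3, stride=2, dilation=1 -> total_pad=1,
// so SAME_UPPER yields {0, 1} and SAME_LOWER yields {1, 0}.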

onnxruntime/core/providers/webnn/builders/impl/builder_utils.h

Lines changed: 2 additions & 4 deletions
@@ -21,8 +21,7 @@ common::Status HandleAutoPad(const std::vector<int64_t> input_shape,
                             const std::vector<int64_t>& onnx_strides,
                             const std::vector<int64_t>& onnx_dilations,
                             AutoPadType auto_pad_type,
-                            std::vector<int64_t>& pads_out,
-                            bool use_nchw) ORT_MUST_USE_RESULT;
+                            std::vector<int64_t>& pads_out) ORT_MUST_USE_RESULT;
 
 // Compute pads and output shape for ConvTranspose.
 common::Status ComputeConvTransposePadsAndOutputShape(const std::vector<int64_t> input_shape,
@@ -34,8 +33,7 @@ common::Status ComputeConvTransposePadsAndOutputShape(const std::vector<int64_t>
                                                      const std::vector<int64_t>& onnx_output_padding,
                                                      AutoPadType auto_pad_type,
                                                      std::vector<int64_t>& pads_out,
-                                                     std::vector<int64_t>& output_shape_out,
-                                                     bool use_nchw) ORT_MUST_USE_RESULT;
+                                                     std::vector<int64_t>& output_shape_out) ORT_MUST_USE_RESULT;
 
 }  // namespace webnn
 }  // namespace onnxruntime

onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc

Lines changed: 17 additions & 165 deletions
@@ -17,9 +17,6 @@ namespace webnn {
 
 class ConvOpBuilder : public BaseOpBuilder {
   // Add operator related.
- public:
-  void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override;
-
  private:
   Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
                               const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
@@ -34,13 +31,6 @@ class ConvOpBuilder : public BaseOpBuilder {
                         const logging::Logger& logger) const override;
 };
 
-void ConvOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
-  // skip the weight for conv as we need to transpose for preferred layout NHWC.
-  if (model_builder.GetPreferredLayout() == DataLayout::NHWC) {
-    model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name());  // W
-  }
-}
-
 // Helper functions
 common::Status SetConvBaseOptions(ModelBuilder& model_builder,
                                   const Node& node, emscripten::val& options,
@@ -49,7 +39,6 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder,
                                   const std::vector<int64_t>& strides,
                                   const std::vector<int64_t>& dilations,
                                   std::vector<int64_t>& pads,
-                                  const bool is_nhwc,
                                   const bool is_conv1d,
                                   const logging::Logger& logger) {
   NodeAttrHelper helper(node);
@@ -63,7 +52,7 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder,
     // Calculate explicit padding for autoPad.
     if (AutoPadType::SAME_UPPER == auto_pad_type || AutoPadType::SAME_LOWER == auto_pad_type) {
       ORT_RETURN_IF_ERROR(HandleAutoPad(input_shape, weight_shape[2], weight_shape[3],
-                                        pads, strides, dilations, auto_pad_type, pads_out, !is_nhwc));
+                                        pads, strides, dilations, auto_pad_type, pads_out));
       pads = pads_out;
     }
   } else if (op_type == "ConvTranspose") {
@@ -84,7 +73,7 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder,
     // Otherwise compute the output shape, as well as the pads if the auto_pad attribute is SAME_UPPER/SAME_LOWER.
     ORT_RETURN_IF_ERROR(ComputeConvTransposePadsAndOutputShape(input_shape, weight_shape[2], weight_shape[3],
                                                                pads, strides, dilations, output_padding,
-                                                               auto_pad_type, pads_out, output_shape, !is_nhwc));
+                                                               auto_pad_type, pads_out, output_shape));
 
     if (output_shape[0] != -1 && output_shape[1] != -1) {
       options.set("outputSizes", emscripten::val::array(GetNarrowedIntFromInt64<uint32_t>(output_shape)));
@@ -113,90 +102,6 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder,
   return Status::OK();
 }
 
-// Both depthwise Conv and ConvTranspose share the same logic to add the layout.
-Status AddInitializerInNewLayout(ModelBuilder& model_builder,
-                                 const std::string& name,
-                                 bool is_conv,
-                                 bool is_conv1d) {
-  const auto& tensor = *model_builder.GetInitializerTensors().at(name);
-  auto data_type = tensor.data_type();
-
-  const auto& shape = tensor.dims();
-  std::vector<uint32_t> dims =
-      GetNarrowedIntFromInt64<uint32_t>(std::vector<int64_t>(std::begin(shape), std::end(shape)));
-
-  if (is_conv1d) {
-    // Support conv1d by prepending a 1 size dimension.
-    dims.push_back(1);
-  }
-
-  const uint8_t* src = nullptr;
-  Initializer unpacked_tensor(tensor, model_builder.GetGraphViewer().ModelPath());
-  src = unpacked_tensor.DataAsByteSpan().data();
-  const auto out_t = dims[0], in_t = dims[1],
-             h_t = dims[2], w_t = dims[3];
-  std::vector<uint32_t> dest_shape;
-  if (is_conv)
-    dest_shape = {out_t, h_t, w_t, in_t};  // L_0231
-  else
-    dest_shape = {in_t, h_t, w_t, out_t};  // L_1230 for depthwise conv and convTranspose weight
-
-  SafeInt<size_t> num_elements = SafeInt<size_t>(Product(dest_shape));
-
-  size_t element_size{0};
-  switch (data_type) {
-    case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
-      element_size = sizeof(uint8_t);
-      break;
-    case ONNX_NAMESPACE::TensorProto_DataType_INT8:
-      element_size = sizeof(int8_t);
-      break;
-    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
-      element_size = sizeof(uint16_t);
-      break;
-    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
-      element_size = sizeof(float);
-      break;
-    default:
-      break;
-  }
-  std::unique_ptr<uint8_t[]> buffer_holder(new uint8_t[element_size * num_elements]);
-  uint8_t* buffer = buffer_holder.get();
-
-  for (uint32_t out = 0; out < out_t; out++) {
-    for (uint32_t in = 0; in < in_t; in++) {
-      for (uint32_t h = 0; h < h_t; h++) {
-        for (uint32_t w = 0; w < w_t; w++) {
-          auto onnx_idx = out * in_t * h_t * w_t +
-                          in * h_t * w_t +
-                          h * w_t +
-                          w;
-
-          uint32_t wnn_idx;
-          if (is_conv == 1) {  // L_0231
-            wnn_idx = out * h_t * w_t * in_t +
-                      h * w_t * in_t +
-                      w * in_t +
-                      in;
-          } else {  // L_1230 for depthwise conv weight
-            wnn_idx = in * h_t * w_t * out_t +
-                      h * w_t * out_t +
-                      w * out_t +
-                      out;
-          }
-
-          for (size_t i = 0; i < element_size; i++) {
-            buffer[element_size * wnn_idx + i] = src[element_size * onnx_idx + i];
-          }
-        }
-      }
-    }
-  }
-  ORT_RETURN_IF_ERROR(model_builder.AddOperandFromPersistMemoryBuffer(name, buffer, num_elements * element_size,
-                                                                      dest_shape, data_type));
-  return Status::OK();
-}
-
 // Add operator related.
 
 Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
@@ -213,31 +118,25 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
   std::vector<int64_t> weight_shape;
   ORT_RETURN_IF_NOT(GetShape(*input_defs[1], weight_shape, logger), "Cannot get weight shape");
   const auto& weight_name = input_defs[1]->Name();
+  emscripten::val filter = model_builder.GetOperand(weight_name);
 
   NodeAttrHelper helper(node);
   auto strides = helper.Get("strides", std::vector<int64_t>{1, 1});
   auto dilations = helper.Get("dilations", std::vector<int64_t>{1, 1});
   auto pads = helper.Get("pads", std::vector<int64_t>{0, 0, 0, 0});
 
-  const bool is_nhwc = model_builder.GetPreferredLayout() == DataLayout::NHWC;
   const bool is_conv1d = input_shape.size() == 3 && weight_shape.size() == 3;
-  const bool is_constant_weight = Contains(initializers, weight_name);
 
   emscripten::val common_options = emscripten::val::object();
   // Support conv1d by prepending a 1 or 2 size dimensions.
   if (is_conv1d) {
     // Reshape input.
-    if (is_nhwc) {
-      // For NHWC preferred layout, the input has been transposed.
-      // For conv1d it is NCD1 -> ND1C, so we need to prepend 1 to the index 2.
-      input_shape.insert(input_shape.begin() + 2, 1);
-    } else {
-      input_shape.push_back(1);
-    }
-    std::vector<uint32_t> new_shape = GetNarrowedIntFromInt64<uint32_t>(input_shape);
+    input_shape.push_back(1);
+    std::vector<uint32_t> new_input_shape = GetNarrowedIntFromInt64<uint32_t>(input_shape);
     common_options.set("label", node.Name() + "_reshape_input");
     input = model_builder.GetBuilder().call<emscripten::val>("reshape", input,
-                                                             emscripten::val::array(new_shape), common_options);
+                                                             emscripten::val::array(new_input_shape),
+                                                             common_options);
 
     weight_shape.resize(4, 1);  // Ensure 4D by appending 1's if needed.
     strides.resize(2, 1);       // Ensure 2D by appending 1's if needed.
@@ -246,68 +145,21 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
       pads.insert(pads.begin() + 1, 0);
       pads.push_back(0);
     }
+
+    // Reshape weight to 4D for conv1d.
+    // The weight_shape has been appended 1's, reshape weight operand.
+    std::vector<uint32_t> new_weight_shape = GetNarrowedIntFromInt64<uint32_t>(weight_shape);
+    common_options.set("label", node.Name() + "_reshape_filter");
+    filter = model_builder.GetBuilder().call<emscripten::val>("reshape",
+                                                              filter,
+                                                              emscripten::val::array(new_weight_shape),
+                                                              common_options);
   }
 
   emscripten::val options = emscripten::val::object();
   options.set("label", node.Name());
   ORT_RETURN_IF_ERROR(SetConvBaseOptions(
-      model_builder, node, options, input_shape, weight_shape, strides, dilations, pads, is_nhwc, is_conv1d, logger));
-  bool depthwise = false;
-  if (op_type == "Conv" || op_type == "ConvInteger") {
-    int groups = options["groups"].as<int>();
-    if (is_nhwc) {
-      depthwise = (groups == input_shape[3] && groups != 1);
-      options.set("inputLayout", emscripten::val("nhwc"));
-      if (is_constant_weight) {
-        ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight_name, !depthwise, is_conv1d));
-      }
-      if (!depthwise) {
-        options.set("filterLayout", emscripten::val("ohwi"));
-      } else {
-        options.set("filterLayout", emscripten::val("ihwo"));
-      }
-    }
-  } else {  // ConvTranspose
-    if (is_nhwc) {
-      options.set("inputLayout", emscripten::val("nhwc"));
-      options.set("filterLayout", emscripten::val("ohwi"));
-      if (is_constant_weight) {
-        ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight_name, false, is_conv1d));
-      }
-    }
-  }
-
-  emscripten::val filter = model_builder.GetOperand(weight_name);
-
-  if (is_conv1d) {
-    // Reshape weight to 4D for conv1d.
-    if (!is_nhwc || !is_constant_weight) {
-      // The weight_shape has been appended 1's, reshape weight operand.
-      std::vector<uint32_t> new_shape = GetNarrowedIntFromInt64<uint32_t>(weight_shape);
-      common_options.set("label", node.Name() + "_reshape_filter");
-      filter = model_builder.GetBuilder().call<emscripten::val>("reshape",
-                                                                filter,
-                                                                emscripten::val::array(new_shape),
-                                                                common_options);
-    }
-  }
-
-  if (is_nhwc && !is_constant_weight) {
-    // For NHWC preferred layout, if the weight is input:
-    // - Transpose it from iohw -> ohwi for convTranspose.
-    // - Transpose it from oihw -> ihwo for depthwise conv.
-    // - Transpose it from oihw -> ohwi for conv.
-    std::vector<uint32_t> perm(4);
-    if (op_type == "ConvTranspose" || depthwise) {
-      perm = {1, 2, 3, 0};  // L_1230 for depthwise conv and convTranspose weight
-    } else {
-      perm = {0, 2, 3, 1};  // L_0231
-    }
-    emscripten::val transpose_options = emscripten::val::object();
-    transpose_options.set("permutation", emscripten::val::array(perm));
-    transpose_options.set("label", node.Name() + "_transpose_filter");
-    filter = model_builder.GetBuilder().call<emscripten::val>("transpose", filter, transpose_options);
-  }
+      model_builder, node, options, input_shape, weight_shape, strides, dilations, pads, is_conv1d, logger));
 
   if (op_type == "Conv") {
     output = model_builder.GetBuilder().call<emscripten::val>("conv2d", input, filter, options);
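
For reference, the filter permutations the deleted NHWC path applied, reconstructed from the removed code's own comments (L_0231, L_1230); OhwiIndex is an illustrative condensation of the deleted AddInitializerInNewLayout indexing loop, not code from the repository:

// Reconstruction of the deleted NHWC filter relayout (illustrative names).
#include <array>
#include <cstdint>

// oihw -> ohwi for regular conv filters ("L_0231").
constexpr std::array<uint32_t, 4> kConvPerm = {0, 2, 3, 1};
// oihw -> ihwo for depthwise conv, and iohw -> ohwi for ConvTranspose,
// are the same index permutation ("L_1230").
constexpr std::array<uint32_t, 4> kDepthwiseOrTransposePerm = {1, 2, 3, 0};

// Flat destination index of element (o, i, h, w) after the oihw -> ohwi
// relayout, matching the non-depthwise branch of the deleted loop.
constexpr uint32_t OhwiIndex(uint32_t o, uint32_t i, uint32_t h, uint32_t w,
                             uint32_t I, uint32_t H, uint32_t W) {
  return o * H * W * I + h * W * I + w * I + i;
}

With the NHWC path gone, the EP hands WebNN NCHW inputs and oihw/iohw filters directly, which match the WebNN conv2d and convTranspose2d defaults, so no relayout or filterLayout option is needed.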
