@@ -17,9 +17,6 @@ namespace webnn {

 class ConvOpBuilder : public BaseOpBuilder {
   // Add operator related.
- public:
-  void AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const override;
-
  private:
   Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
                                const logging::Logger& logger) const override ORT_MUST_USE_RESULT;
@@ -34,13 +31,6 @@ class ConvOpBuilder : public BaseOpBuilder {
                          const logging::Logger& logger) const override;
 };

-void ConvOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
-  // skip the weight for conv as we need to transpose for preferred layout NHWC.
-  if (model_builder.GetPreferredLayout() == DataLayout::NHWC) {
-    model_builder.AddInitializerToSkip(node.InputDefs()[1]->Name());  // W
-  }
-}
-
 // Helper functions
 common::Status SetConvBaseOptions(ModelBuilder& model_builder,
                                   const Node& node, emscripten::val& options,
@@ -49,7 +39,6 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder,
                                   const std::vector<int64_t>& strides,
                                   const std::vector<int64_t>& dilations,
                                   std::vector<int64_t>& pads,
-                                  const bool is_nhwc,
                                   const bool is_conv1d,
                                   const logging::Logger& logger) {
   NodeAttrHelper helper(node);
@@ -63,7 +52,7 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder,
     // Calculate explicit padding for autoPad.
     if (AutoPadType::SAME_UPPER == auto_pad_type || AutoPadType::SAME_LOWER == auto_pad_type) {
       ORT_RETURN_IF_ERROR(HandleAutoPad(input_shape, weight_shape[2], weight_shape[3],
-                                        pads, strides, dilations, auto_pad_type, pads_out, !is_nhwc));
+                                        pads, strides, dilations, auto_pad_type, pads_out));
       pads = pads_out;
     }
   } else if (op_type == "ConvTranspose") {
@@ -84,7 +73,7 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder,
     // Otherwise compute the output shape, as well as the pads if the auto_pad attribute is SAME_UPPER/SAME_LOWER.
     ORT_RETURN_IF_ERROR(ComputeConvTransposePadsAndOutputShape(input_shape, weight_shape[2], weight_shape[3],
                                                                pads, strides, dilations, output_padding,
-                                                               auto_pad_type, pads_out, output_shape, !is_nhwc));
+                                                               auto_pad_type, pads_out, output_shape));

     if (output_shape[0] != -1 && output_shape[1] != -1) {
       options.set("outputSizes", emscripten::val::array(GetNarrowedIntFromInt64<uint32_t>(output_shape)));
@@ -113,90 +102,6 @@ common::Status SetConvBaseOptions(ModelBuilder& model_builder,
   return Status::OK();
 }

-// Both depthwise Conv and ConvTranspose share the same logic to add the layout.
-Status AddInitializerInNewLayout(ModelBuilder& model_builder,
-                                 const std::string& name,
-                                 bool is_conv,
-                                 bool is_conv1d) {
-  const auto& tensor = *model_builder.GetInitializerTensors().at(name);
-  auto data_type = tensor.data_type();
-
-  const auto& shape = tensor.dims();
-  std::vector<uint32_t> dims =
-      GetNarrowedIntFromInt64<uint32_t>(std::vector<int64_t>(std::begin(shape), std::end(shape)));
-
-  if (is_conv1d) {
-    // Support conv1d by prepending a 1 size dimension.
-    dims.push_back(1);
-  }
-
-  const uint8_t* src = nullptr;
-  Initializer unpacked_tensor(tensor, model_builder.GetGraphViewer().ModelPath());
-  src = unpacked_tensor.DataAsByteSpan().data();
-  const auto out_t = dims[0], in_t = dims[1],
-             h_t = dims[2], w_t = dims[3];
-  std::vector<uint32_t> dest_shape;
-  if (is_conv)
-    dest_shape = {out_t, h_t, w_t, in_t};  // L_0231
-  else
-    dest_shape = {in_t, h_t, w_t, out_t};  // L_1230 for depthwise conv and convTranspose weight
-
-  SafeInt<size_t> num_elements = SafeInt<size_t>(Product(dest_shape));
-
-  size_t element_size{0};
-  switch (data_type) {
-    case ONNX_NAMESPACE::TensorProto_DataType_UINT8:
-      element_size = sizeof(uint8_t);
-      break;
-    case ONNX_NAMESPACE::TensorProto_DataType_INT8:
-      element_size = sizeof(int8_t);
-      break;
-    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
-      element_size = sizeof(uint16_t);
-      break;
-    case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
-      element_size = sizeof(float);
-      break;
-    default:
-      break;
-  }
-  std::unique_ptr<uint8_t[]> buffer_holder(new uint8_t[element_size * num_elements]);
-  uint8_t* buffer = buffer_holder.get();
-
-  for (uint32_t out = 0; out < out_t; out++) {
-    for (uint32_t in = 0; in < in_t; in++) {
-      for (uint32_t h = 0; h < h_t; h++) {
-        for (uint32_t w = 0; w < w_t; w++) {
-          auto onnx_idx = out * in_t * h_t * w_t +
-                          in * h_t * w_t +
-                          h * w_t +
-                          w;
-
-          uint32_t wnn_idx;
-          if (is_conv == 1) {  // L_0231
-            wnn_idx = out * h_t * w_t * in_t +
-                      h * w_t * in_t +
-                      w * in_t +
-                      in;
-          } else {  // L_1230 for depthwise conv weight
-            wnn_idx = in * h_t * w_t * out_t +
-                      h * w_t * out_t +
-                      w * out_t +
-                      out;
-          }
-
-          for (size_t i = 0; i < element_size; i++) {
-            buffer[element_size * wnn_idx + i] = src[element_size * onnx_idx + i];
-          }
-        }
-      }
-    }
-  }
-  ORT_RETURN_IF_ERROR(model_builder.AddOperandFromPersistMemoryBuffer(name, buffer, num_elements * element_size,
-                                                                      dest_shape, data_type));
-  return Status::OK();
-}
-
 // Add operator related.

 Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node,
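For reference, the AddInitializerInNewLayout helper removed above repacked a constant 4-D weight with source dims {d0, d1, H, W} into either {d0, H, W, d1} ("L_0231", regular conv) or {d1, H, W, d0} ("L_1230", depthwise conv and ConvTranspose), copying byte-wise so one loop could serve uint8/int8/float16/float initializers. The sketch below is a minimal standalone illustration of that index remapping, not ORT code; the function name and the use of float data are placeholders.

#include <cstddef>
#include <cstdint>
#include <vector>

// Repack src (row-major dims {d0, d1, H, W}) into {d0, H, W, d1} when to_0231 is
// true, or into {d1, H, W, d0} otherwise, element by element.
std::vector<float> RepackWeight(const std::vector<float>& src,
                                uint32_t d0, uint32_t d1, uint32_t H, uint32_t W,
                                bool to_0231) {
  std::vector<float> dst(src.size());
  for (uint32_t a = 0; a < d0; ++a)
    for (uint32_t b = 0; b < d1; ++b)
      for (uint32_t h = 0; h < H; ++h)
        for (uint32_t w = 0; w < W; ++w) {
          const std::size_t src_idx = ((static_cast<std::size_t>(a) * d1 + b) * H + h) * W + w;
          const std::size_t dst_idx = to_0231
                                          ? ((static_cast<std::size_t>(a) * H + h) * W + w) * d1 + b
                                          : ((static_cast<std::size_t>(b) * H + h) * W + w) * d0 + a;
          dst[dst_idx] = src[src_idx];
        }
  return dst;
}

After this change the weight operand is handed to WebNN unchanged, so no build-time repacking of the initializer takes place.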
@@ -213,31 +118,25 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
   std::vector<int64_t> weight_shape;
   ORT_RETURN_IF_NOT(GetShape(*input_defs[1], weight_shape, logger), "Cannot get weight shape");
   const auto& weight_name = input_defs[1]->Name();
+  emscripten::val filter = model_builder.GetOperand(weight_name);

   NodeAttrHelper helper(node);
   auto strides = helper.Get("strides", std::vector<int64_t>{1, 1});
   auto dilations = helper.Get("dilations", std::vector<int64_t>{1, 1});
   auto pads = helper.Get("pads", std::vector<int64_t>{0, 0, 0, 0});

-  const bool is_nhwc = model_builder.GetPreferredLayout() == DataLayout::NHWC;
   const bool is_conv1d = input_shape.size() == 3 && weight_shape.size() == 3;
-  const bool is_constant_weight = Contains(initializers, weight_name);

   emscripten::val common_options = emscripten::val::object();
   // Support conv1d by prepending a 1 or 2 size dimensions.
   if (is_conv1d) {
     // Reshape input.
-    if (is_nhwc) {
-      // For NHWC preferred layout, the input has been transposed.
-      // For conv1d it is NCD1 -> ND1C, so we need to prepend 1 to the index 2.
-      input_shape.insert(input_shape.begin() + 2, 1);
-    } else {
-      input_shape.push_back(1);
-    }
-    std::vector<uint32_t> new_shape = GetNarrowedIntFromInt64<uint32_t>(input_shape);
+    input_shape.push_back(1);
+    std::vector<uint32_t> new_input_shape = GetNarrowedIntFromInt64<uint32_t>(input_shape);
     common_options.set("label", node.Name() + "_reshape_input");
     input = model_builder.GetBuilder().call<emscripten::val>("reshape", input,
-                                                             emscripten::val::array(new_shape), common_options);
+                                                             emscripten::val::array(new_input_shape),
+                                                             common_options);

     weight_shape.resize(4, 1);  // Ensure 4D by appending 1's if needed.
     strides.resize(2, 1);       // Ensure 2D by appending 1's if needed.
@@ -246,68 +145,21 @@ Status ConvOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
       pads.insert(pads.begin() + 1, 0);
       pads.push_back(0);
     }
+
+    // Reshape weight to 4D for conv1d.
+    // The weight_shape has been appended 1's, reshape weight operand.
+    std::vector<uint32_t> new_weight_shape = GetNarrowedIntFromInt64<uint32_t>(weight_shape);
+    common_options.set("label", node.Name() + "_reshape_filter");
+    filter = model_builder.GetBuilder().call<emscripten::val>("reshape",
+                                                              filter,
+                                                              emscripten::val::array(new_weight_shape),
+                                                              common_options);
   }

   emscripten::val options = emscripten::val::object();
   options.set("label", node.Name());
   ORT_RETURN_IF_ERROR(SetConvBaseOptions(
-      model_builder, node, options, input_shape, weight_shape, strides, dilations, pads, is_nhwc, is_conv1d, logger));
-  bool depthwise = false;
-  if (op_type == "Conv" || op_type == "ConvInteger") {
-    int groups = options["groups"].as<int>();
-    if (is_nhwc) {
-      depthwise = (groups == input_shape[3] && groups != 1);
-      options.set("inputLayout", emscripten::val("nhwc"));
-      if (is_constant_weight) {
-        ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight_name, !depthwise, is_conv1d));
-      }
-      if (!depthwise) {
-        options.set("filterLayout", emscripten::val("ohwi"));
-      } else {
-        options.set("filterLayout", emscripten::val("ihwo"));
-      }
-    }
-  } else {  // ConvTranspose
-    if (is_nhwc) {
-      options.set("inputLayout", emscripten::val("nhwc"));
-      options.set("filterLayout", emscripten::val("ohwi"));
-      if (is_constant_weight) {
-        ORT_RETURN_IF_ERROR(AddInitializerInNewLayout(model_builder, weight_name, false, is_conv1d));
-      }
-    }
-  }
-
-  emscripten::val filter = model_builder.GetOperand(weight_name);
-
-  if (is_conv1d) {
-    // Reshape weight to 4D for conv1d.
-    if (!is_nhwc || !is_constant_weight) {
-      // The weight_shape has been appended 1's, reshape weight operand.
-      std::vector<uint32_t> new_shape = GetNarrowedIntFromInt64<uint32_t>(weight_shape);
-      common_options.set("label", node.Name() + "_reshape_filter");
-      filter = model_builder.GetBuilder().call<emscripten::val>("reshape",
-                                                                filter,
-                                                                emscripten::val::array(new_shape),
-                                                                common_options);
-    }
-  }
-
-  if (is_nhwc && !is_constant_weight) {
-    // For NHWC preferred layout, if the weight is input:
-    // - Transpose it from iohw -> ohwi for convTranspose.
-    // - Transpose it from oihw -> ihwo for depthwise conv.
-    // - Transpose it from oihw -> ohwi for conv.
-    std::vector<uint32_t> perm(4);
-    if (op_type == "ConvTranspose" || depthwise) {
-      perm = {1, 2, 3, 0};  // L_1230 for depthwise conv and convTranspose weight
-    } else {
-      perm = {0, 2, 3, 1};  // L_0231
-    }
-    emscripten::val transpose_options = emscripten::val::object();
-    transpose_options.set("permutation", emscripten::val::array(perm));
-    transpose_options.set("label", node.Name() + "_transpose_filter");
-    filter = model_builder.GetBuilder().call<emscripten::val>("transpose", filter, transpose_options);
-  }
+      model_builder, node, options, input_shape, weight_shape, strides, dilations, pads, is_conv1d, logger));

   if (op_type == "Conv") {
     output = model_builder.GetBuilder().call<emscripten::val>("conv2d", input, filter, options);
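A side note on the conv1d path touched above: with the NHWC-specific branches gone, a 1-D convolution is promoted to the 2-D case in a single, layout-independent way, by giving the input, the weight, and every attribute an extra spatial dimension of size 1. The hypothetical helper below summarizes that bookkeeping (illustration only, not ORT code; it assumes the pads attribute arrives as {begin, end}).

#include <cstdint>
#include <vector>

// Widen 1-D conv shapes/attributes to their 2-D equivalents by adding a trailing
// spatial dimension of size 1, mirroring the reshape/resize/insert calls in the diff.
void PromoteConv1dToConv2d(std::vector<int64_t>& input_shape,   // {N, C, L}  -> {N, C, L, 1}
                           std::vector<int64_t>& weight_shape,  // {M, C, kL} -> {M, C, kL, 1}
                           std::vector<int64_t>& strides,       // {s}        -> {s, 1}
                           std::vector<int64_t>& dilations,     // {d}        -> {d, 1}
                           std::vector<int64_t>& pads) {        // {b, e}     -> {b, 0, e, 0}
  input_shape.push_back(1);
  weight_shape.resize(4, 1);
  strides.resize(2, 1);
  dilations.resize(2, 1);
  pads.insert(pads.begin() + 1, 0);
  pads.push_back(0);
}

Because no inputLayout or filterLayout option is set anymore, the WebNN defaults apply (per the current spec: "nchw" input, "oihw" conv2d filters, "iohw" convTranspose2d filters), which line up with the ONNX operand order as-is.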