@@ -229,46 +229,46 @@ def forward(self, x):
229229 x = torch .mean (x , (- 1 , - 2 ), keepdim = True )
230230 return x
231231
232- # def test_fp32_channels_last_tagged_reshape_pass_conv_bn_hardtanh_mean_seq(self):
233- # Copy #1 is for input to conv, nchw -> nhwc
234- # Copy #2 is for conv to _native_batch_norm_legit_no_training, nhwc -> nchw
235- # Copy #3 is for input to mean, nchw -> nhwc
236- # Copy #4 is for output, nhwc -> nchw
237-
238- # The graph looks like:
239- # graph():
240- # %arg0_1 : [#users=1] = placeholder[target=arg0_1]
241- # %aten__to_copy_default : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._to_copy.default](args = (%arg0_1,), kwargs = {memory_format: torch.channels_last})
242- # %_param_constant0 : [#users=1] = get_attr[target=_param_constant0]
243- # %_param_constant1 : [#users=1] = get_attr[target=_param_constant1]
244- # %aten_convolution_default : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten__to_copy_default, %_param_constant0, %_param_constant1, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), kwargs = {})
245- # %aten__to_copy_default_1 : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._to_copy.default](args = (%aten_convolution_default,), kwargs = {memory_format: torch.contiguous_format})
246- # %_param_constant2 : [#users=1] = get_attr[target=_param_constant2]
247- # %_param_constant3 : [#users=1] = get_attr[target=_param_constant3]
248- # %_tensor_constant0 : [#users=1] = get_attr[target=_tensor_constant0]
249- # %_tensor_constant1 : [#users=1] = get_attr[target=_tensor_constant1]
250- # %aten__native_batch_norm_legit_no_training_default : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten__to_copy_default_1, %_param_constant2, %_param_constant3, %_tensor_constant0, %_tensor_constant1, 0.1, 1e-05), kwargs = {})
251- # %getitem : [#users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default, 0), kwargs = {})
252- # %aten_hardtanh_default : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.hardtanh.default](args = (%getitem, 0, 6), kwargs = {})
253- # %aten__to_copy_default_2 : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._to_copy.default](args = (%aten_hardtanh_default,), kwargs = {memory_format: torch.channels_last})
254- # %aten_mean_dim : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.mean.dim](args = (%aten__to_copy_default_2, [-1, -2], True), kwargs = {})
255- # %aten__to_copy_default_3 : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._to_copy.default](args = (%aten_mean_dim,), kwargs = {memory_format: torch.contiguous_format})
256- # return [aten__to_copy_default_3]
257- # (
258- # Tester(
259- # self.Conv2dBnHardtanhMeanSequenceModule().eval(),
260- # (torch.randn(1, 1, 6, 6),),
261- # )
262- # .export()
263- # .to_edge()
264- # .run_passes(self.PassStage)
265- # .check_count(
266- # {
267- # self.to_copy_name: 4,
268- # }
269- # )
270- # .run_method_and_compare_outputs()
271- # )
232+ def test_fp32_channels_last_tagged_reshape_pass_conv_bn_hardtanh_mean_seq (self ):
233+ # Copy #1 is for input to conv, nchw -> nhwc
234+ # Copy #2 is for conv to _native_batch_norm_legit_no_training, nhwc -> nchw
235+ # Copy #3 is for input to mean, nchw -> nhwc
236+ # Copy #4 is for output, nhwc -> nchw
237+
238+ # The graph looks like:
239+ # graph():
240+ # %arg0_1 : [#users=1] = placeholder[target=arg0_1]
241+ # %aten__to_copy_default : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._to_copy.default](args = (%arg0_1,), kwargs = {memory_format: torch.channels_last})
242+ # %_param_constant0 : [#users=1] = get_attr[target=_param_constant0]
243+ # %_param_constant1 : [#users=1] = get_attr[target=_param_constant1]
244+ # %aten_convolution_default : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.convolution.default](args = (%aten__to_copy_default, %_param_constant0, %_param_constant1, [2, 2], [1, 1], [1, 1], False, [0, 0], 1), kwargs = {})
245+ # %aten__to_copy_default_1 : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._to_copy.default](args = (%aten_convolution_default,), kwargs = {memory_format: torch.contiguous_format})
246+ # %_param_constant2 : [#users=1] = get_attr[target=_param_constant2]
247+ # %_param_constant3 : [#users=1] = get_attr[target=_param_constant3]
248+ # %_tensor_constant0 : [#users=1] = get_attr[target=_tensor_constant0]
249+ # %_tensor_constant1 : [#users=1] = get_attr[target=_tensor_constant1]
250+ # %aten__native_batch_norm_legit_no_training_default : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._native_batch_norm_legit_no_training.default](args = (%aten__to_copy_default_1, %_param_constant2, %_param_constant3, %_tensor_constant0, %_tensor_constant1, 0.1, 1e-05), kwargs = {})
251+ # %getitem : [#users=1] = call_function[target=operator.getitem](args = (%aten__native_batch_norm_legit_no_training_default, 0), kwargs = {})
252+ # %aten_hardtanh_default : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.hardtanh.default](args = (%getitem, 0, 6), kwargs = {})
253+ # %aten__to_copy_default_2 : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._to_copy.default](args = (%aten_hardtanh_default,), kwargs = {memory_format: torch.channels_last})
254+ # %aten_mean_dim : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten.mean.dim](args = (%aten__to_copy_default_2, [-1, -2], True), kwargs = {})
255+ # %aten__to_copy_default_3 : [#users=1] = call_function[target=executorch.exir.dialects.edge._ops.aten._to_copy.default](args = (%aten_mean_dim,), kwargs = {memory_format: torch.contiguous_format})
256+ # return [aten__to_copy_default_3]
257+ (
258+ Tester (
259+ self .Conv2dBnHardtanhMeanSequenceModule ().eval (),
260+ (torch .randn (1 , 1 , 6 , 6 ),),
261+ )
262+ .export ()
263+ .to_edge ()
264+ .run_passes (self .PassStage )
265+ .check_count (
266+ {
267+ self .to_copy_name : 4 ,
268+ }
269+ )
270+ .run_method_and_compare_outputs ()
271+ )
272272
273273 class Conv2dDynamicQuant (torch .nn .Module ):
274274 def __init__ (self ):
0 commit comments