Skip to content

Commit 89b1099

Browse files
committed
fix old changes
1 parent f9a4066 commit 89b1099

File tree

2 files changed

+92
-111
lines changed

2 files changed

+92
-111
lines changed

backends/xnnpack/_passes/channels_last_tagged_reshape_pass.py

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,6 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
import string
8-
from logging import FATAL
9-
from tokenize import String
107
from typing import Optional, Tuple
118

129
import torch
@@ -106,7 +103,7 @@ def requires_nchw_inputs(self, node: torch.fx.Node) -> bool:
106103
or node.name == "output"
107104
and node.args[0][0]
108105
.meta["val"]
109-
.is_contiguous() # Need to consider output trace so out matches
106+
.is_contiguous()
110107
)
111108

112109
def can_be_converted_to_nhwc(self, node: torch.fx.Node) -> bool:
@@ -287,12 +284,6 @@ def input_to_nhwc(
287284
return
288285
elif self.is_nhwc_node(input_node):
289286
return
290-
# if (
291-
# self.is_nhwc_node(input_node)
292-
# or input_node.op == "placeholder"
293-
# and not input_node.meta["val"][0].is_contiguous()
294-
# ):
295-
# return
296287

297288
if not self.can_be_converted_to_nhwc(input_node):
298289
raise AssertionError(
@@ -360,16 +351,6 @@ def input_to_nchw(
360351
return
361352
elif self.is_nchw_node(input_node):
362353
return
363-
# TODO
364-
# meta trace happens before passes. At the end of pass, meta gets regenerated. eager mode assumes in/out stay same for conv. Linear has implicit nchw conv
365-
# if (
366-
# self.is_nchw_node(
367-
# input_node
368-
# ) # This is triggering as x (placeholder) is tagged as nchw
369-
# or input_node.op == "placeholder"
370-
# and input_node.meta["val"][0].is_contiguous()
371-
# ):
372-
# return
373354

374355
if ChannelsLastTaggedReshapePass.PARTNER_NODE in input_node.meta:
375356
# Already has an associated NCHW node

backends/xnnpack/test/passes/test_channels_last_tagged_reshape.py

Lines changed: 91 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -44,20 +44,20 @@ def setUp(self):
4444
)
4545
dynamic_quant_name = "executorch_exir_dialects_edge__ops_quantized_decomposed_quantize_per_tensor_tensor"
4646

47-
# def test_fp32_channels_last_tagged_reshape_pass(self):
48-
# for module, num_reshape in self.modules.items():
49-
# (
50-
# Tester(module, (torch.randn(1, 1, 6, 6),))
51-
# .export()
52-
# .to_edge()
53-
# .run_passes(self.PassStage)
54-
# .check_count(
55-
# {
56-
# self.to_copy_name: num_reshape,
57-
# }
58-
# )
59-
# .run_method_and_compare_outputs()
60-
# )
47+
def test_fp32_channels_last_tagged_reshape_pass(self):
48+
for module, num_reshape in self.modules.items():
49+
(
50+
Tester(module, (torch.randn(1, 1, 6, 6),))
51+
.export()
52+
.to_edge()
53+
.run_passes(self.PassStage)
54+
.check_count(
55+
{
56+
self.to_copy_name: num_reshape,
57+
}
58+
)
59+
.run_method_and_compare_outputs()
60+
)
6161

6262
class LinearConv(torch.nn.Module):
6363
def __init__(self):
@@ -141,26 +141,26 @@ def test_nchw_input_on_nhwc_op(self):
141141

142142
tester.export().to_edge_transform_and_lower().to_executorch().serialize().run_method_and_compare_outputs()
143143

144-
# def test_qs8_channels_last_tagged_reshape_pass(self):
145-
# for module, num_reshape in self.modules.items():
146-
# (
147-
# Tester(module, (torch.randn(1, 1, 6, 6),))
148-
# .quantize()
149-
# .export()
150-
# .to_edge()
151-
# .run_passes(self.PassStage)
152-
# .check(
153-
# [
154-
# self.quant_name,
155-
# self.dequant_name,
156-
# self.to_copy_name,
157-
# self.quant_name,
158-
# self.dequant_name,
159-
# ]
160-
# * num_reshape
161-
# )
162-
# .run_method_and_compare_outputs()
163-
# )
144+
def test_qs8_channels_last_tagged_reshape_pass(self):
145+
for module, num_reshape in self.modules.items():
146+
(
147+
Tester(module, (torch.randn(1, 1, 6, 6),))
148+
.quantize()
149+
.export()
150+
.to_edge()
151+
.run_passes(self.PassStage)
152+
.check(
153+
[
154+
self.quant_name,
155+
self.dequant_name,
156+
self.to_copy_name,
157+
self.quant_name,
158+
self.dequant_name,
159+
]
160+
* num_reshape
161+
)
162+
.run_method_and_compare_outputs()
163+
)
164164

165165
class ConvRelu(torch.nn.Module):
166166
def __init__(self):
@@ -171,39 +171,39 @@ def __init__(self):
171171
def forward(self, x):
172172
return self.relu(self.conv(x))
173173

174-
# def test_fp32_channels_last_tagged_reshape_pass_conv_relu(self):
175-
# (
176-
# Tester(self.ConvRelu().eval(), (torch.randn(1, 1, 6, 6),))
177-
# .export()
178-
# .to_edge()
179-
# .run_passes(self.PassStage)
180-
# .check(
181-
# [self.to_copy_name, self.conv_name, self.relu_name, self.to_copy_name]
182-
# )
183-
# .run_method_and_compare_outputs()
184-
# )
174+
def test_fp32_channels_last_tagged_reshape_pass_conv_relu(self):
175+
(
176+
Tester(self.ConvRelu().eval(), (torch.randn(1, 1, 6, 6),))
177+
.export()
178+
.to_edge()
179+
.run_passes(self.PassStage)
180+
.check(
181+
[self.to_copy_name, self.conv_name, self.relu_name, self.to_copy_name]
182+
)
183+
.run_method_and_compare_outputs()
184+
)
185185

186-
# def test_qs8_channels_last_tagged_reshape_pass_conv_relu(self):
187-
# (
188-
# Tester(self.ConvRelu().eval(), (torch.randn(1, 1, 6, 6),))
189-
# .quantize()
190-
# .export()
191-
# .to_edge()
192-
# .run_passes(self.PassStage)
193-
# .check(
194-
# [
195-
# self.to_copy_name,
196-
# self.quant_name,
197-
# self.dequant_name,
198-
# self.conv_name,
199-
# self.relu_name,
200-
# self.quant_name,
201-
# self.dequant_name,
202-
# self.to_copy_name,
203-
# ]
204-
# )
205-
# .run_method_and_compare_outputs()
206-
# )
186+
def test_qs8_channels_last_tagged_reshape_pass_conv_relu(self):
187+
(
188+
Tester(self.ConvRelu().eval(), (torch.randn(1, 1, 6, 6),))
189+
.quantize()
190+
.export()
191+
.to_edge()
192+
.run_passes(self.PassStage)
193+
.check(
194+
[
195+
self.to_copy_name,
196+
self.quant_name,
197+
self.dequant_name,
198+
self.conv_name,
199+
self.relu_name,
200+
self.quant_name,
201+
self.dequant_name,
202+
self.to_copy_name,
203+
]
204+
)
205+
.run_method_and_compare_outputs()
206+
)
207207

208208
class Conv2dBnHardtanhMeanSequenceModule(torch.nn.Module):
209209
def __init__(self):
@@ -278,28 +278,28 @@ def __init__(self):
278278
def forward(self, x):
279279
return self.conv(x)
280280

281-
# def test_dq_conv2d_channels_last_tagged_reshape_pass(self) -> None:
282-
# (
283-
# Tester(self.Conv2dDynamicQuant().eval(), (torch.randn(1, 3, 8, 8),))
284-
# .quantize(
285-
# Quantize(
286-
# quantization_config=get_symmetric_quantization_config(
287-
# is_dynamic=True
288-
# )
289-
# )
290-
# )
291-
# .export()
292-
# .to_edge()
293-
# .run_passes(self.PassStage)
294-
# .check(
295-
# [
296-
# self.to_copy_name,
297-
# self.choose_qparams_name,
298-
# self.dynamic_quant_name,
299-
# self.dequant_name,
300-
# self.conv_name,
301-
# self.to_copy_name,
302-
# ]
303-
# )
304-
# .run_method_and_compare_outputs()
305-
# )
281+
def test_dq_conv2d_channels_last_tagged_reshape_pass(self) -> None:
282+
(
283+
Tester(self.Conv2dDynamicQuant().eval(), (torch.randn(1, 3, 8, 8),))
284+
.quantize(
285+
Quantize(
286+
quantization_config=get_symmetric_quantization_config(
287+
is_dynamic=True
288+
)
289+
)
290+
)
291+
.export()
292+
.to_edge()
293+
.run_passes(self.PassStage)
294+
.check(
295+
[
296+
self.to_copy_name,
297+
self.choose_qparams_name,
298+
self.dynamic_quant_name,
299+
self.dequant_name,
300+
self.conv_name,
301+
self.to_copy_name,
302+
]
303+
)
304+
.run_method_and_compare_outputs()
305+
)

0 commit comments

Comments
 (0)