
Commit fc4d0b6

JakeStevens authored and facebook-github-bot committed
Remove SoftmaxQuantizer (#14089)

Summary: Pull Request resolved: #14089. Softmax is not supported on current platforms, so quantizing it only introduces quantize/dequantize pairs into the graph that serve no purpose. This diff removes the quantizer until softmax is supported. Differential Revision: D81964057
1 parent 245630a commit fc4d0b6
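
For context on the unneeded ops: annotating softmax for post-training quantization makes convert_pt2e wrap it in a dequantize/quantize pair, even though no current platform executes softmax in int8. Below is a minimal sketch of that round-trip, not part of the commit, using the fixed qparams asserted by the removed test further down (op names as registered in torch's quantized_decomposed namespace):

    import torch
    import torch.ao.quantization.fx._decomposed  # noqa: F401  registers the quantized_decomposed ops

    # Fixed qparams for an op whose outputs lie in [0, 1), as asserted by the
    # removed test_quantizer_softmax below.
    scale, zero_point = 1.0 / 256.0, -128

    x = torch.softmax(torch.randn(1, 10), dim=-1)
    q = torch.ops.quantized_decomposed.quantize_per_tensor.default(
        x, scale, zero_point, -128, 127, torch.int8
    )
    dq = torch.ops.quantized_decomposed.dequantize_per_tensor.default(
        q, scale, zero_point, -128, 127, torch.int8
    )
    # With no platform able to run softmax quantized, this q/dq pair is pure
    # graph overhead, which is what this commit eliminates.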

File tree

5 files changed (+6, -44 lines)


backends/nxp/quantizer/neutron_quantizer.py

Lines changed: 0 additions & 2 deletions

@@ -35,7 +35,6 @@
     ReshapePattern,
     SharedSpecPattern,
     SigmoidPattern,
-    SoftMaxPattern,
     TanhInPlacePattern,
     TanhPattern,
     ViewPattern,
@@ -225,7 +224,6 @@ def __init__(self):
             NeutronAtenQuantizer(ReluInPlacePattern(), static_qconfig),
             NeutronAtenQuantizer(ReshapePattern(), static_qconfig),
             NeutronAtenQuantizer(SigmoidPattern(), static_qconfig),
-            NeutronAtenQuantizer(SoftMaxPattern(), static_qconfig),
             NeutronAtenQuantizer(TanhPattern(), static_qconfig),
             NeutronAtenQuantizer(TanhInPlacePattern(), static_qconfig),
             NeutronAtenQuantizer(ViewPattern(), static_qconfig),
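
If a future platform gains quantized-softmax support, re-enabling is the inverse of this hunk. A sketch, taken verbatim from the two removed lines (static_qconfig is the config already defined in this module):

    # In the pattern imports at the top of neutron_quantizer.py:
    SoftMaxPattern,

    # In NeutronQuantizer.__init__, alongside the other per-op quantizers:
    NeutronAtenQuantizer(SoftMaxPattern(), static_qconfig),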

backends/nxp/tests/ir/edge_passes/test_remove_io_quant_ops_pass.py

Lines changed: 3 additions & 2 deletions

@@ -59,12 +59,13 @@ def test_remove_io_quant_ops_pass__cifarnet():
     )
 
     nodes = list(exec_prog.exported_program().graph.nodes)
-    assert len(nodes) == 11
+    assert len(nodes) == 9
     assert (
         nodes[0].meta["val"].dtype == torch.int8
     ), "Input tensor doesn't have type INT8."
+    # Currently, softmax is not quantized
     assert (
-        nodes[10].meta["val"][0].dtype == torch.int8
+        nodes[8].meta["val"][0].dtype == torch.float32
     ), "Output tensor doesn't have type INT8."
 
     assert (

backends/nxp/tests/test_edge_passes.py

Lines changed: 1 addition & 1 deletion

@@ -72,7 +72,7 @@ def unsupported_target(*_):
     exported_program = epm.exported_program()
 
     nodes = list(exported_program.graph_module.graph.nodes)
-    assert len(nodes) == 28
+    assert len(nodes) == 26
 
     view_copy_indices = _find_view_copy_node_indices(nodes)

backends/nxp/tests/test_integration.py

Lines changed: 2 additions & 2 deletions

@@ -27,7 +27,7 @@ def test_conv_fc_softmax__to_executorch_program():
 
     delegation_info = get_delegation_info(program.graph_module)
     assert delegation_info.num_delegated_subgraphs == 1
-    assert delegation_info.num_non_delegated_nodes == 11
+    assert delegation_info.num_non_delegated_nodes == 7
     assert delegation_info.num_delegated_nodes == 13
 
     for node in program.graph.nodes:
@@ -43,7 +43,7 @@ def test_cifarnet():
 
     delegation_info = get_delegation_info(exec_prog.exported_program().graph_module)
     assert delegation_info.num_delegated_subgraphs == 1
-    assert delegation_info.num_non_delegated_nodes == 11
+    assert delegation_info.num_non_delegated_nodes == 7
     assert delegation_info.num_delegated_nodes == 45
 
     nodes = list(exec_prog.exported_program().graph.nodes)

backends/nxp/tests/test_quantizer.py

Lines changed: 0 additions & 37 deletions

@@ -131,43 +131,6 @@ def test_quantizer_maxpool2d():
     assert input_quant == output_quant
 
 
-def test_quantizer_softmax():
-    model = models.SoftmaxModule(dim=0)
-    model.eval()
-
-    example_input = (torch.ones(1, 10),)
-    quantizer = NeutronQuantizer()
-    graph_module = torch.export.export(model, example_input, strict=True).module()
-
-    # noinspection PyTypeChecker
-    m = prepare_pt2e(graph_module, quantizer)
-    m(*example_input)
-    m = convert_pt2e(m)
-
-    # Dry run
-    m(*example_input)
-
-    nodes = list(m.graph.nodes)
-    assert len(nodes) == 7
-    # Check if QDQ pattern:
-    assert nodes[3].name == "softmax"
-    assert (
-        _get_target_name(nodes[3].args[0])
-        == "torch.ops.quantized_decomposed.dequantize_per_tensor.default"
-    )
-    assert (
-        _get_target_name(nodes[4])
-        == "torch.ops.quantized_decomposed.quantize_per_tensor.default"
-    )
-    assert nodes[4].args[0].name == "softmax"
-
-    # Check output quantization
-    scale, zp, _, _, dtype = nodes[4].args[1:]
-    assert scale == 1.0 / 256.0
-    assert zp == -128
-    assert dtype == torch.int8
-
-
 def test_quantizer_single_maxpool2d():
     model = models.MaxPool2dModule()
     model.eval()
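
The deleted test pinned softmax's output quantization to scale = 1/256 and zero_point = -128. Those are the natural fixed parameters for an int8 op whose outputs lie in [0, 1): zero maps to -128, and values approaching 1.0 saturate at 127. A small sketch of that affine mapping (the helper name is hypothetical):

    import torch

    SCALE, ZERO_POINT = 1.0 / 256.0, -128  # qparams the removed test asserted

    def quantize_softmax_int8(x: torch.Tensor) -> torch.Tensor:
        # q = round(x / scale) + zero_point, clamped to the int8 range
        return torch.clamp(torch.round(x / SCALE) + ZERO_POINT, -128, 127).to(torch.int8)

    probs = torch.softmax(torch.randn(1, 10), dim=-1)
    print(quantize_softmax_int8(probs))  # x = 0.0 -> -128; x = 1.0 would clamp 128 -> 127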
