Commit 9ba1b5d

Update AdvanceQuantizeOpAboveDefChainPass, PostponeDequantizeOpBelowUseChainPass, and ScalarToTensorPass to correctly set modified bit
Differential Revision: D87900822
Pull Request resolved: #16230
1 parent 3dd80c1 commit 9ba1b5d
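
For context: Cadence AOT passes report whether they changed the graph through the modified flag of the PassResult they return, and the pass manager relies on that flag to decide whether further processing is needed. The sketch below is a minimal, hypothetical pass illustrating that convention; it is not code from this commit, and the import path and the pass itself are assumptions based on the usual ExecuTorch pass structure.

    import torch
    from executorch.exir.pass_base import ExportPass, PassResult

    class RemoveDetachCopyPass(ExportPass):  # hypothetical example, not from this commit
        def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
            modified = False
            for node in list(graph_module.graph.nodes):
                # Illustrative rewrite: drop detach-like no-op nodes.
                if node.op == "call_function" and "detach" in str(node.target):
                    node.replace_all_uses_with(node.args[0])
                    graph_module.graph.erase_node(node)
                    modified = True
            if modified:
                # Only clean up and recompile when the graph actually changed.
                graph_module.graph.eliminate_dead_code()
                graph_module.recompile()
            # The second field is the "modified bit" this commit fixes in the reorder passes.
            return PassResult(graph_module, modified)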

File tree

2 files changed: +51 −25 lines changed

backends/cadence/aot/reorder_ops.py

Lines changed: 23 additions & 11 deletions
@@ -299,8 +299,9 @@ def advancing_feasible(self, quant_node: torch.fx.Node):
         # All the conditions satisfied, we advance.
         return True
 
-    def advance_quantize_op(self, graph_module: torch.fx.GraphModule):
+    def advance_quantize_op(self, graph_module: torch.fx.GraphModule) -> bool:
         graph = graph_module.graph
+        modified = False
         for node in reversed(graph.nodes):
             if get_overload_packet(node.target) not in (
                 exir_ops.edge.quantized_decomposed.quantize_per_tensor,
@@ -339,15 +340,19 @@ def advance_quantize_op(self, graph_module: torch.fx.GraphModule):
             # We can safely remove the quant node and trivially quantizable op
             graph.erase_node(node)
             graph.erase_node(trivially_quantizable_op)
+            modified = True
 
-        graph_module.recompile()
-        graph_module.graph.eliminate_dead_code()
+        return modified
 
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         self.graph_module = graph_module
-        self.advance_quantize_op(graph_module)
-        result = super().call(graph_module)
-        return result
+        modified = self.advance_quantize_op(graph_module)
+        if modified:
+            graph_module.recompile()
+            graph_module.graph.eliminate_dead_code()
+            return super().call(graph_module)
+
+        return PassResult(graph_module, False)
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
@@ -474,14 +479,21 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
         # the graph (up to 3 times max, to avoid potential infinite loops)
         self.graph_module = graph_module
         iter_count = 0
-        modified = True
+        local_modified = False
+        overall_modified = False
+
+        while local_modified or iter_count == 0:
+            local_modified = self.postpone_dequantize_op(self.graph_module)
+            overall_modified |= local_modified
+
+            if local_modified:
+                self.graph_module = super().call(self.graph_module).graph_module
 
-        while modified and iter_count < 3:
-            modified = self.postpone_dequantize_op(self.graph_module)
-            self.graph_module = super().call(self.graph_module).graph_module
             iter_count += 1
+            if iter_count == 3:
+                break
 
-        return super().call(self.graph_module)
+        return PassResult(self.graph_module, overall_modified)
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=1))
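
The rewritten call above re-runs postpone_dequantize_op until it stops reporting changes, capped at three iterations, and propagates the accumulated flag through PassResult. Abstracted away from the pass, the control flow is essentially the bounded fixed-point loop sketched below; run_to_fixed_point is a hypothetical helper, assuming rewrite_step returns True when it changed the graph.

    def run_to_fixed_point(graph_module, rewrite_step, max_iters: int = 3) -> bool:
        """Apply rewrite_step until it reports no change or max_iters is hit."""
        overall_modified = False
        for _ in range(max_iters):
            if not rewrite_step(graph_module):
                break  # reached a fixed point; stop early
            overall_modified = True
        return overall_modified  # becomes the PassResult "modified" bit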

backends/cadence/aot/tests/test_reorder_ops_passes.py

Lines changed: 28 additions & 14 deletions
@@ -286,13 +286,14 @@ def test_advance_branched_quantize(self) -> None:
     @torch.no_grad()
     def test_advance_quantize(self) -> None:
         builder = GraphBuilder()
-        x = builder.placeholder("x", torch.randn(16, 1, 6, 32, dtype=torch.float32))
-        weights = builder.placeholder(
-            "weights", torch.randint(-128, 127, (32, 32), dtype=torch.int8)
-        )
+        x_data = torch.randn(16, 1, 32, 6, dtype=torch.float32)
+        weight_data = torch.randint(-128, 127, (32, 32), dtype=torch.int8)
+        x = builder.placeholder("x", x_data)
+        weights = builder.placeholder("weights", weight_data)
         full = builder.call_operator(
             op=exir_ops.edge.aten.full.default,
             args=([1], -7),
+            kwargs={"dtype": torch.int32},
         )
         full_1 = builder.call_operator(
             op=exir_ops.edge.aten.full.default,
@@ -304,7 +305,8 @@ def test_advance_quantize(self) -> None:
         )
         full_3 = builder.call_operator(
             op=exir_ops.edge.aten.full.default,
-            args=([12], 0.0),
+            args=([1], 0),
+            kwargs={"dtype": torch.int32},
         )
         permute = builder.call_operator(
             op=exir_ops.edge.aten.permute_copy.default,
@@ -337,8 +339,13 @@ def test_advance_quantize(self) -> None:
 
         p1 = AdvanceQuantizeOpAboveDefInBranchPass()
         tmp_graph = cast(PassResult, p1(original_graph)).graph_module
-        p2 = AdvanceQuantizeOpAboveDefChainPass()
-        converted_graph = cast(PassResult, p2(tmp_graph)).graph_module
+        result = transform_and_check_numerics(
+            tmp_graph,
+            (x_data, weight_data),
+            AdvanceQuantizeOpAboveDefChainPass(),
+        )
+        self.assertFalse(result.modified)
+        converted_graph = result.graph_module
         # Assert that permute node is now the successor of the quant node.
         self.assertTrue(
             get_node_pos(
@@ -349,13 +356,14 @@ def test_advance_quantize(self) -> None:
 
     def test_postpone_dequantize1(self) -> None:
         builder = GraphBuilder()
-        x = builder.placeholder("x", torch.randn(1, 16, 32, 6, dtype=torch.float32))
-        weights = builder.placeholder(
-            "weights", torch.randint(-128, 127, (6, 6), dtype=torch.int8)
-        )
+        x_data = torch.randn(1, 16, 32, 6, dtype=torch.float32)
+        weight_data = torch.randint(-128, 127, (6, 6), dtype=torch.int8)
+        x = builder.placeholder("x", x_data)
+        weights = builder.placeholder("weights", weight_data)
         full = builder.call_operator(
             op=exir_ops.edge.aten.full.default,
             args=([1], -7),
+            kwargs={"dtype": torch.int32},
         )
         full_1 = builder.call_operator(
             op=exir_ops.edge.aten.full.default,
@@ -367,7 +375,8 @@ def test_postpone_dequantize1(self) -> None:
         )
         full_3 = builder.call_operator(
             op=exir_ops.edge.aten.full.default,
-            args=([12], 0.0),
+            args=([1], 0),
+            kwargs={"dtype": torch.int32},
         )
         quantize_per_tensor = builder.call_operator(
             op=exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
@@ -397,8 +406,13 @@ def test_postpone_dequantize1(self) -> None:
         )
         builder.output([permute])
         original_graph = builder.get_graph_module()
-        p = PostponeDequantizeOpBelowUseChainPass()
-        converted_graph = cast(PassResult, p(original_graph)).graph_module
+        result = transform_and_check_numerics(
+            original_graph,
+            (x_data, weight_data),
+            PostponeDequantizeOpBelowUseChainPass(),
+        )
+        self.assertTrue(result.modified)
+        converted_graph = result.graph_module
         # Assert that dequant node is now the successor of the permute node.
         self.assertTrue(
             get_node_pos(converted_graph, exir_ops.edge.aten.permute_copy.default)
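
The tests now route each reorder pass through transform_and_check_numerics instead of invoking the pass directly, so the transformation is also checked for numerical equivalence and its modified flag can be asserted. The helper's implementation is not part of this diff; the sketch below is only an assumption of what such a utility could look like.

    import torch

    def transform_and_check_numerics(graph_module, inputs, reorder_pass):
        # Assumed behavior: run the graph before and after the pass on the same
        # inputs, require matching outputs, and return the PassResult so tests
        # can inspect .modified and .graph_module.
        reference_outputs = graph_module(*inputs)
        result = reorder_pass(graph_module)
        torch.testing.assert_close(result.graph_module(*inputs), reference_outputs)
        return result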
