Skip to content

Commit 4844c1c

Browse files
chunnienccopybara-github
authored andcommitted
Fix RemoveSDPACompositeZeroMaskPass with new decomp rules
PiperOrigin-RevId: 719918842
1 parent 50f279c commit 4844c1c

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

ai_edge_torch/generative/fx_passes/remove_sdpa_zero_mask_pass.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,17 @@
1616
from ai_edge_torch import lowertools
1717
import torch
1818

19+
fx_infra.decomp.remove_pre_convert_decomp(torch.ops.aten.zeros.default)
20+
fx_infra.decomp.remove_pre_convert_decomp(torch.ops.aten.zeros_like.default)
21+
1922

2023
class RemoveSDPACompositeZeroMaskPass(fx_infra.ExportedProgramPassBase):
2124

2225
def is_zero_tensor_node(self, node: torch.fx.Node):
23-
return node.target == torch.ops.aten.zeros.default
26+
return node.target in (
27+
torch.ops.aten.zeros.default,
28+
torch.ops.aten.zeros_like.default,
29+
)
2430

2531
def call(self, exported_program: torch.export.ExportedProgram):
2632
graph = exported_program.graph_module.graph

ai_edge_torch/generative/fx_passes/test/test_remove_sdpa_zero_mask_pass.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,9 @@ def forward(self, *args, **kwargs):
4141
module = func
4242

4343
exported_program = torch.export.export(module, export_args)
44+
exported_program = fx_infra.safe_run_decompositions(
45+
exported_program, fx_infra.decomp.pre_convert_decomp()
46+
)
4447
exported_program = fx_infra.run_passes(
4548
exported_program,
4649
[

0 commit comments

Comments
 (0)