Skip to content

Commit fe0ddf1

Browse files
authored
Allow pattern replacement to ignore literals (#2519)
* When replacing literals with placeholders lists are always converted to tuples Summary: THis is needed because lists are not hashable, since they are mutable, and as a result we cannot have literals_to_ph in pattern rewrites used inside reference_representation_rewrite.py Test Plan: CI + next diff relies on this feature Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] * Allow pattern replacement to ignore literals Summary: This is necessary because sometimes the patterns found have literals include tuple of ints kind of literals. This values shouldnt be used for pattern matching since often they are based on consts derived from example inputs. THis is not exactly a safe thing to do in general so by default it is turned off Test Plan: Subsequent diff adds a pattern that relies on this Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned] * Update on "Allow pattern replacement to ignore literals" Summary: This is necessary because sometimes the patterns found have literals include tuple of ints kind of literals. This values shouldnt be used for pattern matching since often they are based on consts derived from example inputs. THis is not exactly a safe thing to do in general so by default it is turned off Test Plan: Subsequent diff adds a pattern that relies on this Reviewers: Subscribers: Tasks: Tags: [ghstack-poisoned]
1 parent 510e1b4 commit fe0ddf1

File tree

1 file changed

+9
-2
lines changed

1 file changed

+9
-2
lines changed

torchao/quantization/pt2e/reference_representation_rewrite.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from torch._higher_order_ops.out_dtype import out_dtype
1515
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa: F401
1616
from torch.fx import GraphModule
17-
from torch.fx.subgraph_rewriter import replace_pattern
17+
from torch.fx.subgraph_rewriter import replace_pattern_with_filters
1818

1919
from torchao.quantization.pt2e.export_utils import WrapperModule
2020
from torchao.quantization.pt2e.utils import (
@@ -627,6 +627,7 @@ class _RewriteInfo:
627627
# post transformation on the exported pattern and replacement GraphModule
628628
pattern_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None
629629
replacement_post_trans: Optional[Callable[[GraphModule], GraphModule]] = None
630+
ignore_literals: bool = False
630631

631632

632633
def reference_representation_rewrite(model: GraphModule) -> GraphModule:
@@ -830,6 +831,12 @@ def reference_representation_rewrite(model: GraphModule) -> GraphModule:
830831
replacement = replacement_post_trans(replacement)
831832
pattern.recompile() # type: ignore[attr-defined]
832833
replacement.recompile() # type: ignore[attr-defined]
833-
replace_pattern(model, pattern, replacement)
834+
replace_pattern_with_filters(
835+
model,
836+
pattern,
837+
replacement,
838+
match_filters=None,
839+
ignore_literals=rewrite_info.ignore_literals,
840+
) # type: ignore[arg-type]
834841

835842
return model

0 commit comments

Comments
 (0)