6464 compute_unbacked_bindings ,
6565 free_symbols ,
6666 free_unbacked_symbols ,
67+ IterateExprs ,
6768 rebind_unbacked ,
6869 resolve_unbacked_bindings ,
6970 ShapeEnv ,
9798 argsort ,
9899 argsort_sym ,
99100 cache_on_self ,
101+ cache_on_self_and_args ,
100102 ceildiv ,
101103 convert_shape_to_inductor ,
102104 convert_shape_to_symint ,
@@ -933,6 +935,7 @@ class Loops(IRNode):
933935 inner_fn : Callable [..., Any ]
934936 ranges : Sequence [_IntLike ]
935937
938+ @cache_on_self_and_args ("Loops" )
936939 def get_free_symbol_uses (
937940 self , unbacked_only : bool = False
938941 ) -> OrderedSet [sympy .Symbol ]:
@@ -1222,6 +1225,7 @@ def __str__(self) -> str:
12221225
12231226 __repr__ = __str__
12241227
1228+ @cache_on_self_and_args ("Reduction" )
12251229 def get_free_symbol_uses (self , unbacked_only : bool = False ) -> OrderedSet [Symbol ]:
12261230 return super ().get_free_symbol_uses (unbacked_only ) | OrderedSet ().union (
12271231 * (get_free_symbols (e , unbacked_only ) for e in self .reduction_ranges )
@@ -2311,6 +2315,7 @@ class Scan(Loops):
23112315
23122316 # HACK we mimic reduction
23132317
2318+ @cache_on_self_and_args ("Scan" )
23142319 def get_free_symbol_uses (self , unbacked_only : bool = False ) -> OrderedSet [Symbol ]:
23152320 # TODO: Can combine_fn/reindex close over unbacked symbols? If so, we
23162321 # need to explicitly represent the closure so we can pull out unbacked
@@ -2520,6 +2525,7 @@ class Sort(Loops):
25202525
25212526 # HACK we mimic reduction
25222527
2528+ @cache_on_self_and_args ("Sort" )
25232529 def get_free_symbol_uses (self , unbacked_only : bool = False ) -> OrderedSet [Symbol ]:
25242530 return (
25252531 super ().get_free_symbol_uses (unbacked_only )
@@ -2768,6 +2774,7 @@ def is_unaligned(node: IRNode) -> bool:
27682774class BaseView (IRNode ):
27692775 data : IRNode
27702776
2777+ @cache_on_self_and_args ("BaseView" )
27712778 def get_free_symbol_uses (self , unbacked_only : bool = False ) -> OrderedSet [Symbol ]:
27722779 return self .data .get_free_symbol_uses (unbacked_only )
27732780
@@ -3334,6 +3341,7 @@ def get_layout(self) -> Layout:
33343341 def freeze_layout (self ) -> None :
33353342 pass
33363343
3344+ @cache_on_self_and_args ("ReinterpretView" )
33373345 def get_free_symbol_uses (
33383346 self , unbacked_only : bool = False
33393347 ) -> OrderedSet [sympy .Symbol ]:
@@ -3617,13 +3625,37 @@ def __init__(
36173625 self .dtype = dtype
36183626 assert len (size ) == len (stride ), f"size={ size } , stride={ stride } "
36193627 assert all (isinstance (s , (Expr , int )) for s in size )
3620- self .size = size
3621- self .stride = stride
3622- self .offset = offset
3628+ self ._size = size
3629+ self ._stride = stride
3630+ self ._offset = offset
36233631 self .is_pinned = is_pinned
36243632 # is_pinned implies cpu
36253633 assert (not self .is_pinned ) or (self .device .type == "cpu" )
36263634
3635+ @property
3636+ def size (self ) -> Sequence [Expr ]:
3637+ return self ._size
3638+
3639+ @size .setter
3640+ def size (self , value : Sequence [Expr ]) -> None :
3641+ self ._size = value
3642+
3643+ @property
3644+ def stride (self ) -> Sequence [Expr ]:
3645+ return self ._stride
3646+
3647+ @stride .setter
3648+ def stride (self , value : Sequence [Expr ]) -> None :
3649+ self ._stride = value
3650+
3651+ @property
3652+ def offset (self ) -> Expr :
3653+ return self ._offset
3654+
3655+ @offset .setter
3656+ def offset (self , value : Expr ) -> None :
3657+ self ._offset = value
3658+
36273659 def __str__ (self ) -> str :
36283660 offset = ""
36293661 if self .offset != 0 :
@@ -3833,6 +3865,7 @@ def __eq__(self, other: object) -> bool:
38333865 def storage_size (self ) -> Expr :
38343866 return compute_required_storage_length (self .size , self .stride , self .offset ) # type: ignore[arg-type]
38353867
3868+ @cache_on_self_and_args ("Layout" )
38363869 def get_free_symbol_uses (
38373870 self , unbacked_only : bool = False
38383871 ) -> OrderedSet [sympy .Symbol ]:
@@ -3852,7 +3885,11 @@ def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]:
38523885
38533886
38543887class FlexibleLayout (Layout ):
3855- """A Tensor layout that we are allowed to change"""
3888+ """
3889+ A Tensor layout that we are allowed to change
3890+
3891+ Assumption: layout change should NOT add or remove free symbols
3892+ """
38563893
38573894 allow_indexing = False
38583895
@@ -3937,6 +3974,33 @@ def same_ordered(
39373974 fill_order = sorted (range (len (stride )), key = stride .__getitem__ )
39383975 return FlexibleLayout .fill_ordered (sizes , fill_order )
39393976
3977+ @property
3978+ def size (self ) -> Sequence [Expr ]:
3979+ return self ._size
3980+
3981+ @size .setter
3982+ def size (self , value : Sequence [Expr ]) -> None :
3983+ self .assert_free_symbol_uses_unchanged ("size" , value )
3984+ self ._size = value
3985+
3986+ @property
3987+ def stride (self ) -> Sequence [Expr ]:
3988+ return self ._stride
3989+
3990+ @stride .setter
3991+ def stride (self , value : Sequence [Expr ]) -> None :
3992+ self .assert_free_symbol_uses_unchanged ("stride" , value )
3993+ self ._stride = value
3994+
3995+ @property
3996+ def offset (self ) -> Expr :
3997+ return self ._offset
3998+
3999+ @offset .setter
4000+ def offset (self , value : Expr ) -> None :
4001+ self .assert_free_symbol_uses_unchanged ("offset" , value )
4002+ self ._offset = value
4003+
39404004 def as_stride_order (
39414005 self , order : Sequence [int ], allow_padding : bool = False
39424006 ) -> FixedLayout :
@@ -3995,6 +4059,25 @@ def as_same_order(self, stride: Sequence[_IntLike]) -> FixedLayout:
39954059 self .is_pinned ,
39964060 )
39974061
4062+ def get_initial_free_symbol_uses (self ) -> dict [tuple [str , bool ], sympy .Symbol ]:
4063+ initial_free_symbols = {}
4064+ for name in ["size" , "stride" , "offset" ]:
4065+ for unbacked_only in [True , False ]:
4066+ key = (name , unbacked_only )
4067+ initial_free_symbols [key ] = OrderedSet (
4068+ get_free_symbols (getattr (self , name ), unbacked_only )
4069+ )
4070+
4071+ return initial_free_symbols
4072+
4073+ def assert_free_symbol_uses_unchanged (self , name : str , value : IterateExprs ) -> None :
4074+ for unbacked_only in [True , False ]:
4075+ old_free_symbols = self .initial_free_symbols [(name , unbacked_only )]
4076+ new_free_symbols = OrderedSet (get_free_symbols (value , unbacked_only ))
4077+ assert new_free_symbols == old_free_symbols , (
4078+ f"Expected free symbols unchanged, but got { new_free_symbols } vs { old_free_symbols } "
4079+ )
4080+
39984081 def __init__ (
39994082 self ,
40004083 device : torch .device ,
@@ -4009,6 +4092,10 @@ def __init__(
40094092 strides = FlexibleLayout .contiguous_strides (size )
40104093 super ().__init__ (device , dtype , size , strides , is_pinned = is_pinned )
40114094
4095+ # record the initial free symbols to check that we do not add new free symbols
4096+ # later when modifying sizes, strides, and offsets.
4097+ self .initial_free_symbols = self .get_initial_free_symbol_uses ()
4098+
40124099
40134100class NonOwningLayout (Layout ):
40144101 """Is a view into the storage of another tensor"""
@@ -4034,6 +4121,7 @@ def maybe_guard_aligned(self) -> bool:
40344121
40354122 return V .graph .sizevars .statically_known_multiple_of (offset , ALIGNMENT )
40364123
4124+ @cache_on_self_and_args ("NonOwningLayout" )
40374125 def get_free_symbol_uses (
40384126 self , unbacked_only : bool = False
40394127 ) -> OrderedSet [sympy .Symbol ]:
@@ -4322,6 +4410,7 @@ def get_mutation_names(self) -> Sequence[str]:
43224410 def get_read_names (self ) -> OrderedSet [str ]:
43234411 return OrderedSet ([self .get_name ()])
43244412
4413+ @cache_on_self_and_args ("Buffer" )
43254414 def get_free_symbol_uses (
43264415 self , unbacked_only : bool = False
43274416 ) -> OrderedSet [sympy .Symbol ]:
@@ -4394,6 +4483,7 @@ class NoneAsConstantBuffer(IRNode):
43944483 def get_reads (self ) -> OrderedSet [Dep ]:
43954484 return OrderedSet ()
43964485
4486+ @cache_on_self_and_args ("NoneAsConstantBuffer" )
43974487 def get_free_symbol_uses (
43984488 self , unbacked_only : bool = False
43994489 ) -> OrderedSet [sympy .Symbol ]:
@@ -4413,6 +4503,7 @@ def has_tensor_output(self) -> bool:
44134503class ShapeAsConstantBuffer (IRNode ):
44144504 expr : Expr
44154505
4506+ @cache_on_self_and_args ("ShapeAsConstantBuffer" )
44164507 def get_free_symbol_uses (
44174508 self , unbacked_only : bool = False
44184509 ) -> OrderedSet [sympy .Symbol ]:
@@ -4485,6 +4576,7 @@ def get_read_writes(self) -> dependencies.ReadWrites:
44854576 self .data .get_size (),
44864577 )
44874578
4579+ @cache_on_self_and_args ("ComputedBuffer" )
44884580 def get_free_symbol_uses (
44894581 self , unbacked_only : bool = False
44904582 ) -> OrderedSet [sympy .Symbol ]:
@@ -4912,6 +5004,7 @@ def __init__(
49125004 self .subgraph_inps : Optional [list [Optional [Union [IRNode , sympy .Expr ]]]] = None
49135005 self .subgraph_outs : Optional [list [Optional [IRNode ]]] = None
49145006
5007+ @cache_on_self_and_args ("TritonTemplateBuffer" )
49155008 def get_free_symbol_uses (
49165009 self , unbacked_only : bool = False
49175010 ) -> OrderedSet [sympy .Symbol ]:
@@ -5264,6 +5357,7 @@ def is_extern(self) -> bool:
52645357 def num_reads (self ) -> int :
52655358 return 1
52665359
5360+ @cache_on_self_and_args ("InputsKernel" )
52675361 def get_free_symbol_uses (
52685362 self , unbacked_only : bool = False
52695363 ) -> OrderedSet [sympy .Symbol ]:
@@ -5438,6 +5532,7 @@ def can_realize_into_without_copy(
54385532 and not isinstance (src .data , ExternKernelAlloc )
54395533 )
54405534
5535+ @cache_on_self_and_args ("ConcatKernel" )
54415536 def get_free_symbol_uses (
54425537 self , unbacked_only : bool = False
54435538 ) -> OrderedSet [sympy .Symbol ]:
@@ -6337,6 +6432,7 @@ def canonicalize(self) -> tuple[Expr, Sequence[Expr]]:
63376432 index = sympy_subs (sympy .expand (index ), replacement )
63386433 return index , tuple (new_sizes )
63396434
6435+ @cache_on_self_and_args ("ExternKernel" )
63406436 def get_free_symbol_uses (
63416437 self , unbacked_only : bool = False
63426438 ) -> OrderedSet [sympy .Symbol ]:
@@ -6797,6 +6893,7 @@ def codegen(self, wrapper: PythonWrapperCodegen) -> None:
67976893 original_fxnode_name = self .fx_node .name ,
67986894 )
67996895
6896+ @cache_on_self_and_args ("UserDefinedTritonKernel" )
68006897 def get_free_symbol_uses (
68016898 self , unbacked_only : bool = False
68026899 ) -> OrderedSet [sympy .Symbol ]:
@@ -7265,6 +7362,7 @@ def __init__(
72657362 def get_unbacked_symbol_defs (self ) -> OrderedSet [sympy .Symbol ]:
72667363 return OrderedSet ([self .unbacked_offset_symbol ])
72677364
7365+ @cache_on_self_and_args ("DynamicSelectStorageOffset" )
72687366 def get_free_symbol_uses (
72697367 self , unbacked_only : bool = False
72707368 ) -> OrderedSet [sympy .Symbol ]:
@@ -7327,6 +7425,7 @@ def __init__(self, scalar: SympyBoolean, msg: str) -> None:
73277425 def has_side_effects (self ) -> bool :
73287426 return True
73297427
7428+ @cache_on_self_and_args ("AssertScalar" )
73307429 def get_free_symbol_uses (
73317430 self , unbacked_only : bool = False
73327431 ) -> OrderedSet [sympy .Symbol ]:
@@ -7999,6 +8098,7 @@ def __init__(
79998098 self .indices = indices
80008099 self .skip_size_stride_alignment_checks = skip_size_stride_alignment_checks
80018100
8101+ @cache_on_self_and_args ("MultiOutput" )
80028102 def get_free_symbol_uses (
80038103 self , unbacked_only : bool = False
80048104 ) -> OrderedSet [sympy .Symbol ]:
@@ -8121,6 +8221,7 @@ def get_inputs_that_alias_output(self) -> Sequence[str]:
81218221 def realize (self ) -> Optional [str ]:
81228222 return self .data .realize ()
81238223
8224+ @cache_on_self_and_args ("MutableBox" )
81248225 def get_free_symbol_uses (
81258226 self , unbacked_only : bool = False
81268227 ) -> OrderedSet [sympy .Symbol ]:
@@ -8919,6 +9020,7 @@ def has_side_effects(self) -> bool:
89199020
89209021
89219022class NonTensorObj (IRNode ):
9023+ @cache_on_self_and_args ("NonTensorObj" )
89229024 def get_free_symbol_uses (
89239025 self , unbacked_only : bool = False
89249026 ) -> OrderedSet [sympy .Symbol ]:
0 commit comments