Skip to content

Commit 3d27d95

Browse files
[GraphPartition] cache get_free_symbol_uses (pytorch#166338) (pytorch#166994)
Graph partition relies on `get_free_symbol_uses()` to collect symbol inputs. https://github.com/pytorch/pytorch/blob/ee7434be822cf6e75b4566d8159f550ee233d8ae/torch/_inductor/scheduler.py#L4869-L4885 I empirically observed that `get_free_symbol_uses()` becomes slower for larger graphs. Specifically, I tried aten fallback for torchtitan, which results in 10k+ aten nodes. When processing the 600th node, a single `get_free_symbol_uses()` call for one node takes seconds. Why? Because `get_free_symbol_uses()` may recursively call another `get_free_symbol_uses()`, which can in turn recurse many times. https://github.com/pytorch/pytorch/blob/ee7434be822cf6e75b4566d8159f550ee233d8ae/torch/_inductor/ir.py#L4541-L4543 This PR fixes the issue by caching the results of `get_free_symbol_uses()`. I validated on torchtitan that the issue is fixed. Pull Request resolved: pytorch#166338 Approved by: https://github.com/eellison (cherry picked from commit dfebdca) Co-authored-by: Boyuan Feng <[email protected]>
1 parent a06141f commit 3d27d95

File tree

2 files changed

+161
-4
lines changed

2 files changed

+161
-4
lines changed

torch/_inductor/ir.py

Lines changed: 106 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
compute_unbacked_bindings,
6565
free_symbols,
6666
free_unbacked_symbols,
67+
IterateExprs,
6768
rebind_unbacked,
6869
resolve_unbacked_bindings,
6970
ShapeEnv,
@@ -97,6 +98,7 @@
9798
argsort,
9899
argsort_sym,
99100
cache_on_self,
101+
cache_on_self_and_args,
100102
ceildiv,
101103
convert_shape_to_inductor,
102104
convert_shape_to_symint,
@@ -933,6 +935,7 @@ class Loops(IRNode):
933935
inner_fn: Callable[..., Any]
934936
ranges: Sequence[_IntLike]
935937

938+
@cache_on_self_and_args("Loops")
936939
def get_free_symbol_uses(
937940
self, unbacked_only: bool = False
938941
) -> OrderedSet[sympy.Symbol]:
@@ -1222,6 +1225,7 @@ def __str__(self) -> str:
12221225

12231226
__repr__ = __str__
12241227

1228+
@cache_on_self_and_args("Reduction")
12251229
def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
12261230
return super().get_free_symbol_uses(unbacked_only) | OrderedSet().union(
12271231
*(get_free_symbols(e, unbacked_only) for e in self.reduction_ranges)
@@ -2311,6 +2315,7 @@ class Scan(Loops):
23112315

23122316
# HACK we mimic reduction
23132317

2318+
@cache_on_self_and_args("Scan")
23142319
def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
23152320
# TODO: Can combine_fn/reindex close over unbacked symbols? If so, we
23162321
# need to explicitly represent the closure so we can pull out unbacked
@@ -2520,6 +2525,7 @@ class Sort(Loops):
25202525

25212526
# HACK we mimic reduction
25222527

2528+
@cache_on_self_and_args("Sort")
25232529
def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
25242530
return (
25252531
super().get_free_symbol_uses(unbacked_only)
@@ -2768,6 +2774,7 @@ def is_unaligned(node: IRNode) -> bool:
27682774
class BaseView(IRNode):
27692775
data: IRNode
27702776

2777+
@cache_on_self_and_args("BaseView")
27712778
def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
27722779
return self.data.get_free_symbol_uses(unbacked_only)
27732780

@@ -3334,6 +3341,7 @@ def get_layout(self) -> Layout:
33343341
def freeze_layout(self) -> None:
33353342
pass
33363343

3344+
@cache_on_self_and_args("ReinterpretView")
33373345
def get_free_symbol_uses(
33383346
self, unbacked_only: bool = False
33393347
) -> OrderedSet[sympy.Symbol]:
@@ -3617,13 +3625,37 @@ def __init__(
36173625
self.dtype = dtype
36183626
assert len(size) == len(stride), f"size={size}, stride={stride}"
36193627
assert all(isinstance(s, (Expr, int)) for s in size)
3620-
self.size = size
3621-
self.stride = stride
3622-
self.offset = offset
3628+
self._size = size
3629+
self._stride = stride
3630+
self._offset = offset
36233631
self.is_pinned = is_pinned
36243632
# is_pinned implies cpu
36253633
assert (not self.is_pinned) or (self.device.type == "cpu")
36263634

3635+
@property
def size(self) -> Sequence[Expr]:
    """Return the size sequence currently held in ``self._size``."""
    return self._size

@size.setter
def size(self, value: Sequence[Expr]) -> None:
    """Store ``value`` in ``self._size``; no validation is performed here."""
    self._size = value
3642+
3643+
@property
def stride(self) -> Sequence[Expr]:
    """Return the stride sequence currently held in ``self._stride``."""
    return self._stride

@stride.setter
def stride(self, value: Sequence[Expr]) -> None:
    """Store ``value`` in ``self._stride``; no validation is performed here."""
    self._stride = value
3650+
3651+
@property
def offset(self) -> Expr:
    """Return the offset currently held in ``self._offset``."""
    return self._offset

@offset.setter
def offset(self, value: Expr) -> None:
    """Store ``value`` in ``self._offset``; no validation is performed here."""
    self._offset = value
3658+
36273659
def __str__(self) -> str:
36283660
offset = ""
36293661
if self.offset != 0:
@@ -3833,6 +3865,7 @@ def __eq__(self, other: object) -> bool:
38333865
def storage_size(self) -> Expr:
38343866
return compute_required_storage_length(self.size, self.stride, self.offset) # type: ignore[arg-type]
38353867

3868+
@cache_on_self_and_args("Layout")
38363869
def get_free_symbol_uses(
38373870
self, unbacked_only: bool = False
38383871
) -> OrderedSet[sympy.Symbol]:
@@ -3852,7 +3885,11 @@ def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]:
38523885

38533886

38543887
class FlexibleLayout(Layout):
3855-
"""A Tensor layout that we are allowed to change"""
3888+
"""
3889+
A Tensor layout that we are allowed to change
3890+
3891+
Assumption: layout change should NOT add or remove free symbols
3892+
"""
38563893

38573894
allow_indexing = False
38583895

@@ -3937,6 +3974,33 @@ def same_ordered(
39373974
fill_order = sorted(range(len(stride)), key=stride.__getitem__)
39383975
return FlexibleLayout.fill_ordered(sizes, fill_order)
39393976

3977+
@property
def size(self) -> Sequence[Expr]:
    """Return the size sequence currently held in ``self._size``."""
    return self._size

@size.setter
def size(self, value: Sequence[Expr]) -> None:
    """
    Replace the stored size.

    ``assert_free_symbol_uses_unchanged("size", value)`` runs first, so
    ``self._size`` is only overwritten if that check does not raise.
    """
    self.assert_free_symbol_uses_unchanged("size", value)
    self._size = value
3985+
3986+
@property
def stride(self) -> Sequence[Expr]:
    """Return the stride sequence currently held in ``self._stride``."""
    return self._stride

@stride.setter
def stride(self, value: Sequence[Expr]) -> None:
    """
    Replace the stored stride.

    ``assert_free_symbol_uses_unchanged("stride", value)`` runs first, so
    ``self._stride`` is only overwritten if that check does not raise.
    """
    self.assert_free_symbol_uses_unchanged("stride", value)
    self._stride = value
3994+
3995+
@property
def offset(self) -> Expr:
    """Return the offset currently held in ``self._offset``."""
    return self._offset

@offset.setter
def offset(self, value: Expr) -> None:
    """
    Replace the stored offset.

    ``assert_free_symbol_uses_unchanged("offset", value)`` runs first, so
    ``self._offset`` is only overwritten if that check does not raise.
    """
    self.assert_free_symbol_uses_unchanged("offset", value)
    self._offset = value
4003+
39404004
def as_stride_order(
39414005
self, order: Sequence[int], allow_padding: bool = False
39424006
) -> FixedLayout:
@@ -3995,6 +4059,25 @@ def as_same_order(self, stride: Sequence[_IntLike]) -> FixedLayout:
39954059
self.is_pinned,
39964060
)
39974061

4062+
def get_initial_free_symbol_uses(
    self,
) -> dict[tuple[str, bool], OrderedSet[sympy.Symbol]]:
    """
    Snapshot the free symbols of ``size``, ``stride`` and ``offset``.

    Returns:
        A dict keyed by ``(attribute_name, unbacked_only)`` for every
        combination of the three attribute names and both boolean values.
        Each value is the ``OrderedSet`` built from
        ``get_free_symbols(getattr(self, attribute_name), unbacked_only)``.
    """
    initial_free_symbols: dict[tuple[str, bool], OrderedSet[sympy.Symbol]] = {}
    for name in ["size", "stride", "offset"]:
        for unbacked_only in [True, False]:
            key = (name, unbacked_only)
            # Each entry is a set of symbols, not a single symbol.
            initial_free_symbols[key] = OrderedSet(
                get_free_symbols(getattr(self, name), unbacked_only)
            )

    return initial_free_symbols
4072+
4073+
def assert_free_symbol_uses_unchanged(self, name: str, value: IterateExprs) -> None:
    """
    Check that ``value`` has the same free symbols as the recorded snapshot.

    For ``unbacked_only`` in (True, False), the set stored in
    ``self.initial_free_symbols[(name, unbacked_only)]`` is compared against
    the free symbols of ``value``.

    Raises:
        AssertionError: if the two symbol sets differ for either mode.
    """
    for unbacked_only in (True, False):
        snapshot_key = (name, unbacked_only)
        old_free_symbols = self.initial_free_symbols[snapshot_key]
        new_free_symbols = OrderedSet(get_free_symbols(value, unbacked_only))
        assert new_free_symbols == old_free_symbols, (
            f"Expected free symbols unchanged, but got {new_free_symbols} vs {old_free_symbols}"
        )
4080+
39984081
def __init__(
39994082
self,
40004083
device: torch.device,
@@ -4009,6 +4092,10 @@ def __init__(
40094092
strides = FlexibleLayout.contiguous_strides(size)
40104093
super().__init__(device, dtype, size, strides, is_pinned=is_pinned)
40114094

4095+
# record the initial free symbols to check that we do not add new free symbols
4096+
# later when modifying sizes, strides, and offsets.
4097+
self.initial_free_symbols = self.get_initial_free_symbol_uses()
4098+
40124099

40134100
class NonOwningLayout(Layout):
40144101
"""Is a view into the storage of another tensor"""
@@ -4034,6 +4121,7 @@ def maybe_guard_aligned(self) -> bool:
40344121

40354122
return V.graph.sizevars.statically_known_multiple_of(offset, ALIGNMENT)
40364123

4124+
@cache_on_self_and_args("NonOwningLayout")
40374125
def get_free_symbol_uses(
40384126
self, unbacked_only: bool = False
40394127
) -> OrderedSet[sympy.Symbol]:
@@ -4322,6 +4410,7 @@ def get_mutation_names(self) -> Sequence[str]:
43224410
def get_read_names(self) -> OrderedSet[str]:
43234411
return OrderedSet([self.get_name()])
43244412

4413+
@cache_on_self_and_args("Buffer")
43254414
def get_free_symbol_uses(
43264415
self, unbacked_only: bool = False
43274416
) -> OrderedSet[sympy.Symbol]:
@@ -4394,6 +4483,7 @@ class NoneAsConstantBuffer(IRNode):
43944483
def get_reads(self) -> OrderedSet[Dep]:
43954484
return OrderedSet()
43964485

4486+
@cache_on_self_and_args("NoneAsConstantBuffer")
43974487
def get_free_symbol_uses(
43984488
self, unbacked_only: bool = False
43994489
) -> OrderedSet[sympy.Symbol]:
@@ -4413,6 +4503,7 @@ def has_tensor_output(self) -> bool:
44134503
class ShapeAsConstantBuffer(IRNode):
44144504
expr: Expr
44154505

4506+
@cache_on_self_and_args("ShapeAsConstantBuffer")
44164507
def get_free_symbol_uses(
44174508
self, unbacked_only: bool = False
44184509
) -> OrderedSet[sympy.Symbol]:
@@ -4485,6 +4576,7 @@ def get_read_writes(self) -> dependencies.ReadWrites:
44854576
self.data.get_size(),
44864577
)
44874578

4579+
@cache_on_self_and_args("ComputedBuffer")
44884580
def get_free_symbol_uses(
44894581
self, unbacked_only: bool = False
44904582
) -> OrderedSet[sympy.Symbol]:
@@ -4912,6 +5004,7 @@ def __init__(
49125004
self.subgraph_inps: Optional[list[Optional[Union[IRNode, sympy.Expr]]]] = None
49135005
self.subgraph_outs: Optional[list[Optional[IRNode]]] = None
49145006

5007+
@cache_on_self_and_args("TritonTemplateBuffer")
49155008
def get_free_symbol_uses(
49165009
self, unbacked_only: bool = False
49175010
) -> OrderedSet[sympy.Symbol]:
@@ -5264,6 +5357,7 @@ def is_extern(self) -> bool:
52645357
def num_reads(self) -> int:
52655358
return 1
52665359

5360+
@cache_on_self_and_args("InputsKernel")
52675361
def get_free_symbol_uses(
52685362
self, unbacked_only: bool = False
52695363
) -> OrderedSet[sympy.Symbol]:
@@ -5438,6 +5532,7 @@ def can_realize_into_without_copy(
54385532
and not isinstance(src.data, ExternKernelAlloc)
54395533
)
54405534

5535+
@cache_on_self_and_args("ConcatKernel")
54415536
def get_free_symbol_uses(
54425537
self, unbacked_only: bool = False
54435538
) -> OrderedSet[sympy.Symbol]:
@@ -6337,6 +6432,7 @@ def canonicalize(self) -> tuple[Expr, Sequence[Expr]]:
63376432
index = sympy_subs(sympy.expand(index), replacement)
63386433
return index, tuple(new_sizes)
63396434

6435+
@cache_on_self_and_args("ExternKernel")
63406436
def get_free_symbol_uses(
63416437
self, unbacked_only: bool = False
63426438
) -> OrderedSet[sympy.Symbol]:
@@ -6797,6 +6893,7 @@ def codegen(self, wrapper: PythonWrapperCodegen) -> None:
67976893
original_fxnode_name=self.fx_node.name,
67986894
)
67996895

6896+
@cache_on_self_and_args("UserDefinedTritonKernel")
68006897
def get_free_symbol_uses(
68016898
self, unbacked_only: bool = False
68026899
) -> OrderedSet[sympy.Symbol]:
@@ -7265,6 +7362,7 @@ def __init__(
72657362
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
72667363
return OrderedSet([self.unbacked_offset_symbol])
72677364

7365+
@cache_on_self_and_args("DynamicSelectStorageOffset")
72687366
def get_free_symbol_uses(
72697367
self, unbacked_only: bool = False
72707368
) -> OrderedSet[sympy.Symbol]:
@@ -7327,6 +7425,7 @@ def __init__(self, scalar: SympyBoolean, msg: str) -> None:
73277425
def has_side_effects(self) -> bool:
73287426
return True
73297427

7428+
@cache_on_self_and_args("AssertScalar")
73307429
def get_free_symbol_uses(
73317430
self, unbacked_only: bool = False
73327431
) -> OrderedSet[sympy.Symbol]:
@@ -7999,6 +8098,7 @@ def __init__(
79998098
self.indices = indices
80008099
self.skip_size_stride_alignment_checks = skip_size_stride_alignment_checks
80018100

8101+
@cache_on_self_and_args("MultiOutput")
80028102
def get_free_symbol_uses(
80038103
self, unbacked_only: bool = False
80048104
) -> OrderedSet[sympy.Symbol]:
@@ -8121,6 +8221,7 @@ def get_inputs_that_alias_output(self) -> Sequence[str]:
81218221
def realize(self) -> Optional[str]:
81228222
return self.data.realize()
81238223

8224+
@cache_on_self_and_args("MutableBox")
81248225
def get_free_symbol_uses(
81258226
self, unbacked_only: bool = False
81268227
) -> OrderedSet[sympy.Symbol]:
@@ -8919,6 +9020,7 @@ def has_side_effects(self) -> bool:
89199020

89209021

89219022
class NonTensorObj(IRNode):
9023+
@cache_on_self_and_args("NonTensorObj")
89229024
def get_free_symbol_uses(
89239025
self, unbacked_only: bool = False
89249026
) -> OrderedSet[sympy.Symbol]:

torch/_inductor/utils.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -626,6 +626,7 @@ def sort_func(elem: _T) -> str:
626626

627627
P = ParamSpec("P")
628628
RV = TypeVar("RV", covariant=True)
629+
FN_TYPE = Callable[Concatenate[Any, P], RV]
629630

630631

631632
class CachedMethod(Protocol, Generic[P, RV]):
@@ -665,6 +666,60 @@ def clear_cache(self: Any) -> None:
665666
return wrapper # type: ignore[return-value]
666667

667668

669+
def cache_property_on_self(fn: Callable[P, RV]) -> CachedMethod[P, RV]:
    """
    Property-flavoured alias of cache_on_self.

    ``fn`` is forwarded to cache_on_self unchanged; only the declared type
    signature differs.
    """
    # pyrefly: ignore [bad-argument-type]
    cached = cache_on_self(fn)
    return cached
675+
676+
677+
def cache_on_self_and_args(
    class_name: str,
) -> Callable[[FN_TYPE[P, RV]], FN_TYPE[P, RV]]:
    """
    Build a decorator that memoizes a method per instance and per argument set.

    Results live in a dict stored on the instance under the attribute
    ``__{class_name}_{fn.__name__}_cache``. Including ``class_name`` gives an
    override and the parent implementation it reaches through
    ``super().fn(*args, **kwargs)`` separate caches on the same instance.

    Notes:
        * The cache key is ``(args, tuple(sorted(kwargs.items())))``, so all
          arguments must be hashable, and ``f(x)`` and ``f(a=x)`` are cached
          as distinct entries.
        * A cache hit returns the stored object itself, not a copy.
        * The returned function has a ``clear_cache(instance)`` attribute
          that deletes the cache dict from ``instance`` if present.

    Args:
        class_name: name of the class that defines the decorated method; it
            is interpolated into the cache attribute name.
    """

    def wrapper(
        fn: FN_TYPE[P, RV],
    ) -> FN_TYPE[P, RV]:
        key = f"__{class_name}_{fn.__name__}_cache"

        # wrapper is likely on the hot path, compile a specialized version of it.
        # The generated source carries no annotations on purpose: the exec
        # namespace only contains `fn`, so annotated parameters would only be
        # valid when the calling module's `from __future__ import annotations`
        # flag happens to be inherited by exec().
        ctx = {"fn": fn}
        exec(
            f"""\
            def inner(self, *args, **kwargs):
                args_kwargs = (args, tuple(sorted(kwargs.items())))

                if not hasattr(self, "{key}"):
                    object.__setattr__(self, "{key}", {{}})

                cache = self.{key}

                try:
                    return cache[args_kwargs]
                except KeyError:
                    pass

                rv = fn(self, *args, **kwargs)

                cache[args_kwargs] = rv
                return rv
            """.lstrip(),
            ctx,
        )
        inner = functools.wraps(fn)(ctx["inner"])

        def clear_cache(self: Any) -> None:
            if hasattr(self, key):
                delattr(self, key)

        inner.clear_cache = clear_cache  # type: ignore[attr-defined]
        return inner

    return wrapper
721+
722+
668723
def aggregate_origins(
669724
node_schedule: Union[Sequence[BaseSchedulerNode], ExternKernel],
670725
) -> OrderedSet[Node]:

0 commit comments

Comments
 (0)