
Commit 0661ecd

bobrenjc93 authored and pytorchmergebot committed
add support for hint_override in mark_unbacked (pytorch#162652)
Very similar to pytorch#161007 except now for mark_unbacked.

Pull Request resolved: pytorch#162652
Approved by: https://github.com/laithsakka
1 parent 7a0f933 commit 0661ecd
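
As a rough usage sketch (inferred from the added test and docstring below; the shapes, device, and function here are illustrative, not part of the commit): hint_override keeps the dimension unbacked while steering Inductor's size-hint heuristics, such as autotuning, toward a different size than the example input's.

import torch

@torch.compile
def f(x):
    return x.sum(dim=0)

x = torch.randn(4096, 512, device="cuda")  # example input; device is illustrative
# Dim 0 stays unbacked (data-dependent), but Inductor may tune as if it were ~40k.
torch._dynamo.decorators.mark_unbacked(x, 0, hint_override=4096 * 10)
f(x)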

File tree

test/inductor/test_torchinductor.py
torch/_dynamo/decorators.py
torch/_inductor/sizevars.py
torch/fx/experimental/symbolic_shapes.py

4 files changed: +92 -12 lines changed

test/inductor/test_torchinductor.py

Lines changed: 38 additions & 0 deletions
@@ -10696,6 +10696,44 @@ def override(x):

         self.assertEqual(no_override(x_small), override(x_small))

+    @requires_gpu()
+    @skip_if_not_triton
+    @unittest.skipIf(
+        not IS_BIG_GPU, "Skipping triton backend only since not big GPU (not enough SM)"
+    )
+    @config.patch({"force_disable_caches": True})
+    def test_mark_unbacked_with_hint_override(self):
+        @torch.compile
+        def no_override(x):
+            return x.sum(dim=0)
+
+        @torch.compile
+        def override(x):
+            return x.sum(dim=0)
+
+        @torch.compile(fullgraph=True)
+        def branching(x):
+            if x.shape[0] > 4096:
+                return 1
+            return 2
+
+        x_small = torch.randn(4096, 512, device=GPU_TYPE)
+        torch._dynamo.decorators.mark_unbacked(x_small, 0)
+        code1 = run_and_get_triton_code(no_override, x_small)
+
+        torch._dynamo.reset_code_caches()
+
+        torch._dynamo.decorators.mark_unbacked(x_small, 0, hint_override=4096 * 10)
+        code2 = run_and_get_triton_code(override, x_small)
+        self.assertNotEqual(code1, code2)
+
+        self.assertEqual(no_override(x_small), override(x_small))
+
+        with self.assertRaisesRegex(
+            RuntimeError, "Could not guard on data-dependent expression"
+        ):
+            branching(x_small)
+
     @requires_gpu()
     def test_stride_preservation_with_stride_modifying_fx_pass(self):
         def f(x):

torch/_dynamo/decorators.py

Lines changed: 26 additions & 9 deletions
@@ -536,19 +536,26 @@ class directly; instead, use :func:`mark_dynamic`.
 def mark_unbacked(
     t: Any,
     index: Union[int, list[Any], tuple[Any]],
+    hint_override: Optional[int] = None,
     strict: bool = False,
     specialize_on: Optional[list[Any]] = None,
 ) -> None:
     """
-    Mark a tensor as having an unbacked dim. This changes the semantics of operations,
-    we will always report the size does not equal zero/one, we will turn asserts
-    on this index into runtime asserts, and if you try to get the real value we will
-    raise an exception. In other words, we will treat this dimension as if it was
-    data dependent (we do not know anything about its value.)
+    Mark a tensor as having an unbacked dimension. This changes the semantics of operations:
+    - The size of the specified dimension will always be reported as not equal to zero or one.
+    - Assertions on this index will be turned into runtime asserts.
+    - Attempting to get the real value of this dimension will raise an exception.
+    - In effect, this dimension is treated as data-dependent (its value is unknown).

-    For historical reasons, by default if an unbacked dim is specialized, we will
-    happily specialize it and continue. If you want to error in these cases, pass
-    strict=True.
+    Args:
+        t (Any): The tensor to mark as having an unbacked dimension.
+        index (int or list/tuple of int): The dimension(s) to mark as unbacked. Can be a single integer or a list/tuple of integers.
+        hint_override (Optional[int], default=None): An optional integer to override the size hint for this dimension.
+            This is only used by the inductor backend for size hint queries, such as during autotuning.
+        strict (bool, default=False): If True, an error will be raised if the unbacked dimension is specialized.
+            By default (strict=False), specialization is allowed and will proceed without error.
+        specialize_on (Optional[list[Any]], default=None): A list of specialization criteria (e.g., lambdas) for this dimension.
+            If provided, Dynamo will generate specialized compiled regions for each criterion in addition to a generic trace.
     """
     # You could have copied the mark_dynamic behavior but I'm not convinced
     # it's what you want

@@ -567,6 +574,12 @@ def mark_unbacked(
     if not hasattr(t, "_dynamo_unbacked_indices"):
         t._dynamo_unbacked_indices = set()

+    if not hasattr(t, "_dynamo_hint_overrides"):
+        t._dynamo_hint_overrides = {}
+
+    if hint_override:
+        t._dynamo_hint_overrides[index] = hint_override
+
     # FX tracers don't respect @forbid_in_graph and choke on the following error since it passes in proxies:
     # TypeError: 'Attribute' object does not support item assignment
     if isinstance(t._specialize_on, dict):

@@ -612,7 +625,10 @@ def mark_dynamic(
     4) Attempts to trace this function will explicitly raise. As such, all calls to mark_dynamic must be made
     before torch.compile.

-    5) If specialize_on is passed in, we will perform a single generic Dynamo trace followed by
+    5) If hint_override is passed, the hint_override for the specified dimension will replace the provided value
+    from the first example input as the official size hint.
+
+    6) If specialize_on is passed in, we will perform a single generic Dynamo trace followed by
     multiple specialized compilations in addition to a single generic compilation. NB: For now we only support
     per dimension specialization, or in other words we do not generate a cross product of specializations.
     At runtime, we will dispatch to a specialized compiled region if the input matches the specialization criteria.

@@ -626,6 +642,7 @@ def mark_dynamic(
     This approach results in one Dynamo trace and two backend compilations. When the input dimension equals 8 or 16
     at runtime, execution will be directed to the specialized compiled region. Performance measurements indicate
     2-8x speedups depending on the specific specialization and model architecture.
+
     """
     if is_traceable_wrapper_subclass(t):
         # default behavior: mirror mark_dynamic() on all inner tensors with same dim as t
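
For orientation, a small sketch (not part of the diff) of the per-tensor bookkeeping the new mark_unbacked lines perform; the tensor, dim, and value below are made up for illustration:

import torch

t = torch.randn(8, 3)
torch._dynamo.decorators.mark_unbacked(t, 0, hint_override=65536)

# The decorator records the override on the tensor itself, keyed by dim index,
# alongside the existing unbacked-index bookkeeping that Dynamo consumes later.
print(t._dynamo_unbacked_indices)  # expected to contain 0
print(t._dynamo_hint_overrides)    # expected: {0: 65536}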

torch/_inductor/sizevars.py

Lines changed: 22 additions & 3 deletions
@@ -74,6 +74,7 @@ def __init__(self, shape_env=None) -> None:
             shape_env = ShapeEnv()
         self.shape_env = shape_env
         self.var_to_val = self.shape_env.var_to_val
+        self.var_to_hint_override = self.shape_env.var_to_hint_override
         self.replacements: dict[sympy.Symbol, Expr] = self.shape_env.replacements
         self.unbacked_replacements: Optional[dict[Expr, Expr]] = None
         # Maps of dynamic sizes that have to be precomputed on the host to the kernel args.

@@ -544,7 +545,13 @@ def remove_precomputed_replacements(self, expr: Expr) -> Expr:
         return expr

     def symbolic_hint(
-        self, expr: Union[Expr, int], hint_override: Optional[int] = None
+        self,
+        expr: Union[Expr, int],
+        hint_override: Optional[int] = None,
+        # Only flip this flag if you don't plan on guarding/adding runtime
+        # asserts based on this value and promise to only use this value
+        # in a heuristic nature.
+        use_user_provided_hint_override: bool = False,
     ) -> Union[Expr, int]:
         if isinstance(expr, int):
             return expr

@@ -564,6 +571,10 @@ def symbolic_hint(
             return hint_override

         expr = self.remove_precomputed_replacements(expr)
+
+        if use_user_provided_hint_override:
+            expr = sympy_subs(expr, self.var_to_hint_override)
+
         return sympy_subs(expr, self.var_to_val)

     def size_hint(

@@ -573,7 +584,11 @@ def size_hint(
         fallback: Optional[int] = None,
         hint_override: Optional[int] = None,
     ) -> int:
-        out = self.symbolic_hint(expr, hint_override=hint_override)
+        out = self.symbolic_hint(
+            expr,
+            hint_override=hint_override,
+            use_user_provided_hint_override=fallback is not None,
+        )
         if not isinstance(out, (int, sympy.Integer)) and fallback is not None:
             # Use the provided heuristic fallback hint
             unbacked_sym_vrs = {

@@ -610,7 +625,11 @@ def size_hints(
         hint_override: Optional[int] = None,
     ) -> tuple[int, ...]:
         return tuple(
-            self.size_hint(x, fallback=fallback, hint_override=hint_override)
+            self.size_hint(
+                x,
+                fallback=fallback,
+                hint_override=hint_override,
+            )
             for x in exprs
         )
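
To summarize the sizevars change (a simplified sketch under stated assumptions, not the real SizeVarAllocator code): user-provided overrides are only consulted on the heuristic path, i.e. when the caller supplies a fallback and therefore promises not to guard on the result.

import sympy

def heuristic_size_hint(expr, var_to_val, var_to_hint_override, fallback=None):
    # Sketch: substitute user hint overrides only when a fallback is given,
    # then substitute known values; otherwise return the heuristic fallback.
    if fallback is not None:
        expr = expr.subs(var_to_hint_override)
    expr = expr.subs(var_to_val)
    return int(expr) if expr.is_Integer else fallback

u0 = sympy.Symbol("u0")  # stands in for an unbacked size symbol
print(heuristic_size_hint(2 * u0, {}, {u0: 40960}, fallback=8192))  # 81920
print(heuristic_size_hint(2 * u0, {}, {}, fallback=8192))           # 8192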

torch/fx/experimental/symbolic_shapes.py

Lines changed: 6 additions & 0 deletions
@@ -3718,6 +3718,7 @@ def _init(
         self.source_name_to_debug_name: dict[str, str] = {}
         self.var_to_sources: dict[sympy.Symbol, list[Source]] = {}
         self.var_to_stack: dict[sympy.Symbol, CapturedTraceback] = {}
+        self.var_to_hint_override: dict[sympy.Symbol, int] = {}
         # Maps a source to the *original* symbol that was assigned to it
         self.source_to_var: dict[str, sympy.Symbol] = {}
         # Maps from sympy ints to expressions representing them

@@ -4582,6 +4583,11 @@ def _create_symbolic_sizes_strides_storage_offset(
             )
             for i, (sym, hint) in enumerate(zip(size, ex_size))
         ]
+
+        for i, sym in enumerate(sym_sizes):
+            if isinstance(sym, torch.SymInt) and i in hint_overrides:
+                self.var_to_hint_override[sym.node.expr] = hint_overrides[i]
+
         sym_stride = []
         for i, stride_expr in enumerate(stride):
             # NB: Don't duck size the stride; instead use the expression
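
In effect (an assumed illustration of the resulting data flow, not code from the commit), the ShapeEnv ends up holding a map from the unbacked size symbol of the marked dimension to the user's integer hint, which Inductor's SizeVarAllocator later reads as var_to_hint_override:

import sympy

hint_overrides = {0: 4096 * 10}            # dim index -> user-provided hint
u0 = sympy.Symbol("u0")                    # stands in for the dim-0 size symbol
sym_size_exprs = [u0, sympy.Integer(512)]  # symbolic sizes, one per dim

var_to_hint_override: dict = {}
for i, expr in enumerate(sym_size_exprs):
    if i in hint_overrides and expr.free_symbols:  # only symbolic (non-static) dims
        var_to_hint_override[expr] = hint_overrides[i]

print(var_to_hint_override)  # {u0: 40960}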
