@@ -536,19 +536,26 @@ class directly; instead, use :func:`mark_dynamic`.
536
536
def mark_unbacked (
537
537
t : Any ,
538
538
index : Union [int , list [Any ], tuple [Any ]],
539
+ hint_override : Optional [int ] = None ,
539
540
strict : bool = False ,
540
541
specialize_on : Optional [list [Any ]] = None ,
541
542
) -> None :
542
543
"""
543
- Mark a tensor as having an unbacked dim. This changes the semantics of operations,
544
- we will always report the size does not equal zero/one, we will turn asserts
545
- on this index into runtime asserts, and if you try to get the real value we will
546
- raise an exception. In other words, we will treat this dimension as if it was
547
- data dependent (we do not know anything about its value.)
544
+ Mark a tensor as having an unbacked dimension. This changes the semantics of operations:
545
+ - The size of the specified dimension will always be reported as not equal to zero or one.
546
+ - Assertions on this index will be turned into runtime asserts.
547
+ - Attempting to get the real value of this dimension will raise an exception.
548
+ - In effect, this dimension is treated as data-dependent (its value is unknown).
548
549
549
- For historical reasons, by default if an unbacked dim is specialized, we will
550
- happily specialize it and continue. If you want to error in these cases, pass
551
- strict=True.
550
+ Args:
551
+ t (Any): The tensor to mark as having an unbacked dimension.
552
+ index (int or list/tuple of int): The dimension(s) to mark as unbacked. Can be a single integer or a list/tuple of integers.
553
+ hint_override (Optional[int], default=None): An optional integer to override the size hint for this dimension.
554
+ This is only used by the inductor backend for size hint queries, such as during autotuning.
555
+ strict (bool, default=False): If True, an error will be raised if the unbacked dimension is specialized.
556
+ By default (strict=False), specialization is allowed and will proceed without error.
557
+ specialize_on (Optional[list[Any]], default=None): A list of specialization criteria (e.g., lambdas) for this dimension.
558
+ If provided, Dynamo will generate specialized compiled regions for each criterion in addition to a generic trace.
552
559
"""
553
560
# You could have copied the mark_dynamic behavior but I'm not convinced
554
561
# it's what you want
@@ -567,6 +574,12 @@ def mark_unbacked(
567
574
if not hasattr (t , "_dynamo_unbacked_indices" ):
568
575
t ._dynamo_unbacked_indices = set ()
569
576
577
+ if not hasattr (t , "_dynamo_hint_overrides" ):
578
+ t ._dynamo_hint_overrides = {}
579
+
580
+ if hint_override :
581
+ t ._dynamo_hint_overrides [index ] = hint_override
582
+
570
583
# FX tracers don't respect @forbid_in_graph and choke on the following error since it passes in proxies:
571
584
# TypeError: 'Attribute' object does not support item assignment
572
585
if isinstance (t ._specialize_on , dict ):
@@ -612,7 +625,10 @@ def mark_dynamic(
612
625
4) Attempts to trace this function will explicitly raise. As such, all calls to mark_dynamic must be made
613
626
before torch.compile.
614
627
615
- 5) If specialize_on is passed in, we will perform a single generic Dynamo trace followed by
628
+ 5) If hint_override is passed, it replaces the size observed in the first example input as the official
629
+ size hint for the specified dimension.
630
+
631
+ 6) If specialize_on is passed in, we will perform a single generic Dynamo trace followed by
616
632
multiple specialized compilations in addition to a single generic compilation. NB: For now we only support
617
633
per dimension specialization, or in other words we do not generate a cross product of specializations.
618
634
At runtime, we will dispatch to a specialized compiled region if the input matches the specialization criteria.
@@ -626,6 +642,7 @@ def mark_dynamic(
626
642
This approach results in one Dynamo trace and two backend compilations. When the input dimension equals 8 or 16
627
643
at runtime, execution will be directed to the specialized compiled region. Performance measurements indicate
628
644
2-8x speedups depending on the specific specialization and model architecture.
645
+
629
646
"""
630
647
if is_traceable_wrapper_subclass (t ):
631
648
# default behavior: mirror mark_dynamic() on all inner tensors with same dim as t
0 commit comments