# Module-level logger, named after this module per the stdlib logging convention.
log = logging.getLogger(__name__)
4141
4242
43- class _DimHint (Enum ):
43+ class _DimHintType (Enum ):
4444 """
4545 Enum for dynamic shape hints.
4646 - AUTO means automatic inference of shape (static or dynamic).
@@ -53,6 +53,23 @@ class _DimHint(Enum):
5353 DYNAMIC = auto ()
5454
5555
@dataclasses.dataclass
class _DimHint:
    """Instance wrapper around a dynamic-shape hint kind.

    Wraps a ``_DimHintType`` member in a dataclass instance, so hints compare
    by value via the generated ``__eq__`` rather than by enum identity.
    """

    # The wrapped hint kind — presumably one of AUTO / STATIC / DYNAMIC.
    type: _DimHintType

    @staticmethod
    def AUTO() -> "_DimHint":
        """Return a hint requesting automatic (static-or-dynamic) inference."""
        return _DimHint(_DimHintType.AUTO)

    @staticmethod
    def STATIC() -> "_DimHint":
        """Return a hint pinning the dimension as static."""
        return _DimHint(_DimHintType.STATIC)

    @staticmethod
    def DYNAMIC() -> "_DimHint":
        """Return a hint marking the dimension as dynamic."""
        return _DimHint(_DimHintType.DYNAMIC)

72+
5673class _Dim (type ):
5774 """
5875 Metaclass for :func:`Dim` types.
@@ -206,7 +223,7 @@ def _derive(self, fn):
206223 )
207224
208225
class Dim(type):
    """
    :func:`Dim` constructs a type analogous to a named symbolic integer with a range.
    It can be used to describe multiple possible values of a dynamic tensor dimension.

    Args:
        name (str): Human-readable name for debugging; must be a valid Python
            identifier.
        min (Optional[int]): Minimum possible value of the dimension (inclusive).
            Defaults to 0.
        max (Optional[int]): Maximum possible value of the dimension (inclusive).
            Defaults to unbounded (``int_oo``).

    Returns:
        A type that can be used in dynamic shape specifications for tensors.
    """

    # Sentinel hints usable in place of a concrete Dim in dynamic-shape specs.
    AUTO = _DimHint.AUTO()
    DYNAMIC = _DimHint.DYNAMIC()
    STATIC = _DimHint.STATIC()

    def __new__(
        metacls, name: str, *, min: Optional[int] = None, max: Optional[int] = None
    ):
        # Imported lazily to avoid a module-load-time dependency cycle.
        from torch.utils._sympy.numbers import int_oo

        lower = min if min is not None else 0
        upper = max if max is not None else int_oo
        # Range check first, then the name check — order is observable via
        # which assertion fires when both arguments are invalid.
        assert upper > lower, f"Cannot create Dim with inconsistent min={min}, max={max}"
        assert name.isidentifier(), f"Dim name must be a valid identifier, got {name}"
        dim = _Dim(name, (int,), {"min": lower, "max": upper})
        # Attribute the produced type to the *caller's* module for nicer reprs;
        # stack()[1] must stay directly inside __new__ to pick the right frame.
        caller_module = inspect.getmodule(inspect.stack()[1][0])
        dim.__module__ = getattr(caller_module, "__name__", "__main__")
        return dim

241260
242261
243262def dims (
@@ -249,7 +268,7 @@ def dims(
249268 Returns:
250269 A tuple of :func:`Dim` types.
251270 """
252- return tuple (Dim (name , min = min , max = max ) for name in names )
271+ return tuple (Dim (name , min = min , max = max ) for name in names ) # type: ignore[misc]
253272
254273
255274@dataclasses .dataclass
@@ -923,11 +942,11 @@ def _create_static_dim(tensor, i, value):
923942 constraint = to_constraint (dim , tensor , i )
924943 symbols [dim .__name__ ].append (constraint )
925944 elif isinstance (dim , _DimHint ):
926- if dim == _DimHint .AUTO :
945+ if dim . type == _DimHintType .AUTO :
927946 torch ._dynamo .maybe_mark_dynamic (tensor , i )
928- elif dim == _DimHint .STATIC :
947+ elif dim . type == _DimHintType .STATIC :
929948 torch ._dynamo .mark_static (tensor , i )
930- elif dim == _DimHint .DYNAMIC :
949+ elif dim . type == _DimHintType .DYNAMIC :
931950 torch ._dynamo .mark_dynamic (tensor , i )
932951 constraints .append (_RelaxedConstraint (id (tensor ), i ))
933952 elif dim is None :
@@ -940,11 +959,11 @@ def _create_static_dim(tensor, i, value):
940959 constraint = to_constraint (dim , tensor , i )
941960 symbols [dim .__name__ ].append (constraint )
942961 elif isinstance (dim , _DimHint ):
943- if dim == _DimHint .AUTO :
962+ if dim . type == _DimHintType .AUTO :
944963 torch ._dynamo .maybe_mark_dynamic (tensor , i )
945- elif dim == _DimHint .STATIC :
964+ elif dim . type == _DimHintType .STATIC :
946965 torch ._dynamo .mark_static (tensor , i )
947- elif dim == _DimHint .DYNAMIC :
966+ elif dim . type == _DimHintType .DYNAMIC :
948967 torch ._dynamo .mark_dynamic (tensor , i )
949968 constraints .append (_RelaxedConstraint (id (tensor ), i ))
950969 elif dim is None :
@@ -1063,7 +1082,7 @@ def refine_dynamic_shapes_from_suggested_fixes(
10631082 expr = sympy .sympify (expr )
10641083 if isinstance (expr , sympy .Number ):
10651084 # static, integer
1066- shape_fixes [name ] = int (expr )
1085+ shape_fixes [name ] = int (expr ) # type: ignore[assignment]
10671086 else :
10681087 # relation or derived dim
10691088 shape_fixes [name ] = expr
0 commit comments