@@ -70,83 +70,118 @@ def STATIC():
7070 return _DimHint (_DimHintType .STATIC )
7171
7272
73- class _Dim ( type ) :
73+ class Dim :
7474 """
75- Metaclass for :func:`Dim` types.
75+ :func:`Dim` constructs a type analogous to a named symbolic integer with a range.
76+ It can be used to describe multiple possible values of a dynamic tensor dimension.
77+ Note that different dynamic dimensions of the same tensor, or of different tensors,
78+ can be described by the same type.
79+
80+ Args:
81+ name (str): Human-readable name for debugging.
82+ min (Optional[int]): Minimum possible value of given symbol (inclusive)
83+ max (Optional[int]): Maximum possible value of given symbol (inclusive)
84+
85+ Returns:
86+ A type that can be used in dynamic shape specifications for tensors.
7687 """
7788
78- @staticmethod
79- def readable (name , min_ , max_ ):
89+ AUTO = _DimHint .AUTO ()
90+ DYNAMIC = _DimHint .DYNAMIC ()
91+ STATIC = _DimHint .STATIC ()
92+
93+ def __init__ (
94+ self , name : str , * , min : Optional [int ] = None , max : Optional [int ] = None
95+ ):
8096 from torch .utils ._sympy .numbers import int_oo
8197
82- if min_ == 2 :
83- min_ = None
84- if max_ == int_oo :
85- max_ = None
86- if min_ is None and max_ is None :
87- return f"Dim('{ name } ')"
88- if min_ is None :
89- return f"Dim('{ name } ', max={ max_ } )"
90- if max_ is None :
91- return f"Dim('{ name } ', min={ min_ } )"
92- return f"Dim('{ name } ', min={ min_ } , max={ max_ } )"
98+ _min = 0 if min is None else min
99+ _max = int_oo if max is None else max
100+ assert _max > _min , f"Cannot create Dim with inconsistent min={ min } , max={ max } "
101+ assert name .isidentifier (), f"Dim name must be a valid identifier, got { name } "
102+ self .__name__ = name
103+ self .min = _min
104+ self .max = _max
93105
94- def __add__ (cls , other ):
106+ def __add__ (self , other ) -> "Dim" :
95107 # e.g., dim + 1
96108 if type (other ) is not int :
97109 raise NotImplementedError (
98- f"Attempted to add { other } to { cls .__name__ } , where an integer was expected. "
110+ f"Attempted to add { other } to { self .__name__ } , where an integer was expected. "
99111 "(Only increasing linear operations with integer coefficients are supported.)"
100112 )
101- return cls ._derive (lambda x : x + other )
113+ return self ._derive (lambda x : x + other )
102114
103- def __radd__ (cls , other ):
104- return cls + other
115+ def __radd__ (self , other ) -> "Dim" :
116+ return self + other
105117
106- def __sub__ (cls , other ):
118+ def __sub__ (self , other ) -> "Dim" :
107119 # e.g., dim - 1
108120 if type (other ) is not int :
109121 raise NotImplementedError (
110- f"Attempted to subtract { other } from { cls .__name__ } , where an integer was expected. "
122+ f"Attempted to subtract { other } from { self .__name__ } , where an integer was expected. "
111123 "(Only increasing linear operations with integer coefficients are supported.)"
112124 )
113- return cls ._derive (lambda x : x - other )
125+ return self ._derive (lambda x : x - other )
114126
115- def __rsub__ (cls , other ):
127+ def __rsub__ (self , other ) -> "Dim" :
116128 raise NotImplementedError (
117- f"Attempted to negate { cls .__name__ } . "
129+ f"Attempted to negate { self .__name__ } . "
118130 "(Only increasing linear operations with integer coefficients are supported.)"
119131 )
120132
121- def __mul__ (cls , other ):
133+ def __mul__ (self , other ) -> "Dim" :
122134 # e.g., dim * 2
123135 if type (other ) is not int or other <= 0 :
124136 raise NotImplementedError (
125- f"Attempted to multiply { other } with { cls .__name__ } , where a positive integer was expected. "
137+ f"Attempted to multiply { other } with { self .__name__ } , where a positive integer was expected. "
126138 "(Only increasing linear operations with integer coefficients are supported.)"
127139 )
128- return cls ._derive (lambda x : x * other )
140+ return self ._derive (lambda x : x * other )
129141
130- def __rmul__ (cls , other ):
131- return cls * other
142+ def __rmul__ (self , other ) -> "Dim" :
143+ return self * other
132144
133- def _derived_name (cls , fn ):
145+ def _derived_name (self , fn ) -> str :
134146 from sympy import sympify
135147
136- return str (fn (sympify (cls .__name__ )))
148+ return str (fn (sympify (self .__name__ )))
137149
138- def _derive (cls , fn ):
139- return _DerivedDim (cls ._derived_name (fn ), ( int ,), { "root" : cls , "fn" : fn } )
150+ def _derive (self , fn ) -> "Dim" :
151+ return _DerivedDim (self ._derived_name (fn ), self , fn )
140152
153+ @staticmethod
154+ def readable (name : str , min_ : int , max_ : int ) -> str :
155+ from torch .utils ._sympy .numbers import int_oo
141156
142- class _StaticDim (_Dim ):
157+ if min_ == 2 :
158+ min_ = None # type: ignore[assignment]
159+ if max_ == int_oo :
160+ max_ = None # type: ignore[assignment]
161+ if min_ is None and max_ is None :
162+ return f"Dim('{ name } ')"
163+ if min_ is None :
164+ return f"Dim('{ name } ', max={ max_ } )"
165+ if max_ is None :
166+ return f"Dim('{ name } ', min={ min_ } )"
167+ return f"Dim('{ name } ', min={ min_ } , max={ max_ } )"
168+
169+ def __repr__ (self ):
170+ return Dim .readable (self .__name__ , self .min , self .max )
171+
172+
173+ class _StaticDim (Dim ):
143174 """
144- Meta class for static :func:`Dim` types.
175+ Class for static :func:`Dim` types.
145176
146177 This class is only for setting and checking static dim constraints,
147178 and the user should never interact with it.
148179 """
149180
181+ def __init__ (self , value : int ):
182+ self .__name__ = str (value )
183+ self .value = value
184+
150185 @property
151186 def min (self ):
152187 return self .value # type: ignore[attr-defined]
@@ -156,9 +191,9 @@ def max(self):
156191 return self .value # type: ignore[attr-defined]
157192
158193
159- class _DerivedDim (_Dim ):
194+ class _DerivedDim (Dim ):
160195 """
161- Metaclass for derived :func:`Dim` types.
196+ Class for derived :func:`Dim` types.
162197
163198 Currently we only support increasing linear expressions with integer coefficients.
164199 In other words, a derived Dim can always be written in the form Ax + B, where
@@ -172,6 +207,11 @@ class _DerivedDim(_Dim):
172207 The range of a derived Dim is computed by mapping `fn` over the range of its `root`.
173208 """
174209
210+ def __init__ (self , name : str , root : Dim , fn : Callable ):
211+ self .__name__ = name
212+ self .root = root
213+ self .fn = fn
214+
175215 @property
176216 def min (self ):
177217 # assume that self.fn is an increasing function
@@ -218,50 +258,17 @@ def _derive(self, fn):
218258 # As a consequence, roots are always regular Dims (i.e., not derived Dims).
219259 return _DerivedDim (
220260 self ._derived_name (fn ),
221- ( int ,) ,
222- { "root" : self . root , "fn" : lambda x : fn (self .fn (x ))}, # type: ignore[attr-defined]
261+ self . root ,
262+ lambda x : fn (self .fn (x )),
223263 )
224264
225-
226- class Dim (type ):
227- """
228- :func:`Dim` constructs a type analogous to a named symbolic integer with a range.
229- It can be used to describe multiple possible values of a dynamic tensor dimension.
230- Note that different dynamic dimensions of the same tensor, or of different tensors,
231- can be described by the same type.
232-
233- Args:
234- name (str): Human-readable name for debugging.
235- min (Optional[int]): Minimum possible value of given symbol (inclusive)
236- max (Optional[int]): Maximum possible value of given symbol (inclusive)
237-
238- Returns:
239- A type that can be used in dynamic shape specifications for tensors.
240- """
241-
242- AUTO = _DimHint .AUTO ()
243- DYNAMIC = _DimHint .DYNAMIC ()
244- STATIC = _DimHint .STATIC ()
245-
246- def __new__ (
247- metacls , name : str , * , min : Optional [int ] = None , max : Optional [int ] = None
248- ):
249- from torch .utils ._sympy .numbers import int_oo
250-
251- _min = 0 if min is None else min
252- _max = int_oo if max is None else max
253- assert _max > _min , f"Cannot create Dim with inconsistent min={ min } , max={ max } "
254- assert name .isidentifier (), f"Dim name must be a valid identifier, got { name } "
255- dim = _Dim (name , (int ,), {"min" : _min , "max" : _max })
256- dim .__module__ = getattr (
257- inspect .getmodule (inspect .stack ()[1 ][0 ]), "__name__" , "__main__"
258- )
259- return dim
265+ def __repr__ (self ):
266+ return self .__name__
260267
261268
262269def dims (
263270 * names : str , min : Optional [int ] = None , max : Optional [int ] = None
264- ) -> tuple [_Dim , ...]:
271+ ) -> tuple [Dim , ...]:
265272 """
266273 Util to create multiple :func:`Dim` types.
267274
@@ -722,8 +729,8 @@ def check_same_bounds(dim):
722729 if dim .__name__ in bounds :
723730 min_ , max_ = bounds [dim .__name__ ]
724731 if dim .min != min_ or dim .max != max_ :
725- this_ = _Dim .readable (dim .__name__ , min_ , max_ )
726- that_ = _Dim .readable (dim .__name__ , dim .min , dim .max )
732+ this_ = Dim .readable (dim .__name__ , min_ , max_ )
733+ that_ = Dim .readable (dim .__name__ , dim .min , dim .max )
727734 raise UserError (
728735 UserErrorType .INVALID_INPUT ,
729736 f"Found different definitions { this_ } and { that_ } "
@@ -735,7 +742,7 @@ def check_same_bounds(dim):
735742 def check_symbols (path , tensor , shape ):
736743 if isinstance (shape , dict ):
737744 for i , dim in shape .items ():
738- if isinstance (dim , _Dim ):
745+ if isinstance (dim , Dim ):
739746 check_same_bounds (dim )
740747 elif dim is None :
741748 _warn_on_None_dynamic_shape_dimension ()
@@ -750,7 +757,7 @@ def check_symbols(path, tensor, shape):
750757 )
751758 elif isinstance (shape , (tuple , list )):
752759 for i , dim in enumerate (shape ):
753- if isinstance (dim , _Dim ):
760+ if isinstance (dim , Dim ):
754761 check_same_bounds (dim )
755762 elif dim is None :
756763 _warn_on_None_dynamic_shape_dimension ()
@@ -911,7 +918,7 @@ def root_value():
911918 ),
912919 )
913920 else :
914- assert isinstance (dim , _Dim )
921+ assert isinstance (dim , Dim )
915922 constraint = _Constraint ( # type: ignore[assignment]
916923 id (tensor ),
917924 i ,
@@ -924,7 +931,7 @@ def root_value():
924931
925932 def update_symbols (path , tensor , shape ):
926933 def _create_static_dim (tensor , i , value ):
927- return _StaticDim (str ( value ), ( int ,), { "value" : value } )
934+ return _StaticDim (value )
928935
929936 # clean out decorators from user side, or previous export call
930937 # we also delete these attributes in non_strict_utils.py/make_constraints()
@@ -936,7 +943,7 @@ def _create_static_dim(tensor, i, value):
936943
937944 if isinstance (shape , dict ):
938945 for i , dim in shape .items ():
939- if isinstance (dim , (int , _Dim )):
946+ if isinstance (dim , (int , Dim )):
940947 if isinstance (dim , int ):
941948 dim = _create_static_dim (tensor , i , dim )
942949 constraint = to_constraint (dim , tensor , i )
@@ -953,7 +960,7 @@ def _create_static_dim(tensor, i, value):
953960 torch ._dynamo .mark_static (tensor , i )
954961 elif isinstance (shape , (tuple , list )):
955962 for i , dim in enumerate (shape ):
956- if isinstance (dim , (int , _Dim )):
963+ if isinstance (dim , (int , Dim )):
957964 if isinstance (dim , int ):
958965 dim = _create_static_dim (tensor , i , dim )
959966 constraint = to_constraint (dim , tensor , i )
@@ -1002,14 +1009,14 @@ def _get_dim_name_mapping(
10021009 name_to_dim = {}
10031010 for dim in tree_flatten (
10041011 dynamic_shapes ,
1005- is_leaf = lambda x : isinstance (x , _Dim ),
1012+ is_leaf = lambda x : isinstance (x , Dim ),
10061013 )[0 ]:
10071014 if dim is None :
10081015 # NOTE: this must denote a non-Tensor or automatic at this point.
10091016 continue
10101017 if isinstance (dim , int ):
10111018 continue
1012- elif isinstance (dim , _Dim ):
1019+ elif isinstance (dim , Dim ):
10131020 name_to_dim [dim .__name__ ] = dim
10141021 if isinstance (dim , _DerivedDim ):
10151022 name_to_dim [dim .root .__name__ ] = dim .root # type: ignore[attr-defined]
@@ -1092,7 +1099,7 @@ def refine_dynamic_shapes_from_suggested_fixes(
10921099 # track derived dim roots
10931100 roots : set [str ] = set ()
10941101 for k , c in shape_fixes .items ():
1095- assert isinstance (c , (int , _Dim , _DerivedDim , sympy .Expr ))
1102+ assert isinstance (c , (int , Dim , _DerivedDim , sympy .Expr ))
10961103 if isinstance (c , sympy .Expr ): # check dim/derived dim expression
10971104 assert _is_supported_equivalence (c )
10981105 shape_fixes [k ] = c
0 commit comments