Commit 67920bb

pianpwk authored and facebook-github-bot committed
[export] refactor _Dim into Dim (#149891)
Summary:
X-link: meta-pytorch/torchrec#2847
X-link: pytorch/executorch#9559

forward fix T218515233

Test Plan: test_export

Differential Revision: D71769231
1 parent 6ae8eb8 commit 67920bb
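
The user-facing API is unchanged by this refactor: Dim(...) previously returned a fresh class manufactured by the _Dim metaclass (with int as a base), and now returns an ordinary Dim instance. A minimal sketch of typical usage under the new scheme (the module and shapes below are illustrative, not from this commit):

import torch
from torch.export import Dim, export

class M(torch.nn.Module):
    def forward(self, x):
        return x * 2

# Same call syntax as before the refactor; the result is now a Dim instance.
batch = Dim("batch", min=1, max=64)

# dynamic_shapes specs accept Dim objects exactly as they accepted _Dim types.
ep = export(M(), (torch.randn(4, 8),), dynamic_shapes=({0: batch},))

# New in this commit: Dim defines __repr__, so it prints readably
# instead of as a dynamically created class.
print(batch)  # Dim('batch', min=1, max=64)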

6 files changed: +116 additions, -110 deletions


test/export/test_export.py

Lines changed: 2 additions & 2 deletions

@@ -3835,7 +3835,7 @@ def forward(self, x):
 
         dynamic_shapes = (
             {"k": {"k": dim}},
-        )  # ValueError: Node type mismatch; expected <class 'list'>, but got .*_Dim.*.
+        )  # ValueError: Node type mismatch; expected <class 'list'>, but got .*Dim.*.
         with self.assertRaisesRegex(
             torch._dynamo.exc.UserError,
             re.escape(
@@ -12311,7 +12311,7 @@ def test_dynamic_shapes_serdes_user_errors(self):
 
         self.assertExpectedInline(
             _load_dynamic_shapes(spec, from_dict=False),
-            """[[<class 'torch._export.serde.dynamic_shapes.dx'>]]""",
+            """[[Dim('dx', min=4, max=16)]]""",
         )
 
         # check incorrect info in dims
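
The expecttest update above reflects the new __repr__: a deserialized dim now prints as a readable constructor call rather than as a dynamically created class. A small hedged illustration of the same behavior via the public constructor:

from torch.export import Dim

dx = Dim("dx", min=4, max=16)
# __repr__ delegates to Dim.readable, yielding the string the test asserts.
assert repr(dx) == "Dim('dx', min=4, max=16)"

# readable() elides a min of 2 and an unbounded max (see the readable()
# logic in the torch/export/dynamic_shapes.py diff below), so such bounds
# are omitted from the printed form.
assert repr(Dim("dy", min=2)) == "Dim('dy')"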

test/onnx/exporter/test_small_models_e2e.py

Lines changed: 1 addition & 1 deletion

@@ -450,7 +450,7 @@ def forward(
         )
 
         dynamic_shapes = (
-            {0: torch.export.Dim("dim_x", min=3)},  # _Dim
+            {0: torch.export.Dim("dim_x", min=3)},  # Dim
             [("custom_name_axis_ys_0",), (torch.export.Dim.AUTO,)],  # custom name
             {
                 "a": {0: torch.export.Dim.AUTO},

torch/_export/serde/dynamic_shapes.py

Lines changed: 4 additions & 5 deletions

@@ -6,7 +6,6 @@
 from torch.export.dynamic_shapes import (
     _check_dynamic_shapes,
     _DerivedDim,
-    _Dim,
     _DimHint,
     _tree_map_with_path,
     Dim,
@@ -19,7 +18,7 @@
 @dataclasses.dataclass
 class RootDim:
     """
-    This represents a _Dim object.
+    This represents a Dim object.
     """
 
     min: int
@@ -150,7 +149,7 @@ def _standardize_shapes(path, tensor, shape):  # type: ignore[no-untyped-def]
         return out
 
     def _track_dim_from_dims(
-        val: Union[None, int, _DimHint, _Dim]
+        val: Union[None, int, _DimHint, Dim]
     ) -> Union[None, int, str]:
         """
         Tracks dims, ranges, derived dims from the standardized dynamic_shapes spec.
@@ -160,7 +159,7 @@ def _track_dim_from_dims(
         if isinstance(val, _DimHint):  # store enum as string
             return val.__class__.__name__ + "." + val.type.name
 
-        assert isinstance(val, _Dim)
+        assert isinstance(val, Dim)
 
         # track root dim
         root = val.root if isinstance(val, _DerivedDim) else val  # type: ignore[attr-defined]
@@ -297,7 +296,7 @@ def _load_dynamic_shapes(
 
     def deserialize_shape(
         val: Union[None, int, str]
-    ) -> Union[None, int, _Dim, _DimHint]:
+    ) -> Union[None, int, Dim, _DimHint]:
         if val is None or isinstance(val, int):
             return val
         elif val == "_DimHint.AUTO":
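
As a rough sketch of the encoding these helpers use (per the diff above): _DimHint values are stored as strings such as "_DimHint.AUTO", while Dim objects are reduced to their names, with ranges recorded separately as RootDim entries. The function below mirrors the _track_dim_from_dims logic in spirit and uses only the imports shown in the diff:

from torch.export.dynamic_shapes import Dim, _DimHint

def track(val):
    # Mirrors _track_dim_from_dims: ints and None pass through, hints become
    # "_DimHint.<TYPE>" strings, and Dims are reduced to their names.
    if val is None or isinstance(val, int):
        return val
    if isinstance(val, _DimHint):
        return val.__class__.__name__ + "." + val.type.name
    assert isinstance(val, Dim)  # the assert now checks Dim, not _Dim
    return val.__name__

assert track(Dim.AUTO) == "_DimHint.AUTO"
assert track(Dim("dx", min=4, max=16)) == "dx"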

torch/export/dynamic_shapes.py

Lines changed: 94 additions & 87 deletions

@@ -70,83 +70,118 @@ def STATIC():
         return _DimHint(_DimHintType.STATIC)
 
 
-class _Dim(type):
+class Dim:
     """
-    Metaclass for :func:`Dim` types.
+    :func:`Dim` constructs a type analogous to a named symbolic integer with a range.
+    It can be used to describe multiple possible values of a dynamic tensor dimension.
+    Note that different dynamic dimensions of the same tensor, or of different tensors,
+    can be described by the same type.
+
+    Args:
+        name (str): Human-readable name for debugging.
+        min (Optional[int]): Minimum possible value of given symbol (inclusive)
+        max (Optional[int]): Maximum possible value of given symbol (inclusive)
+
+    Returns:
+        A type that can be used in dynamic shape specifications for tensors.
     """
 
-    @staticmethod
-    def readable(name, min_, max_):
+    AUTO = _DimHint.AUTO()
+    DYNAMIC = _DimHint.DYNAMIC()
+    STATIC = _DimHint.STATIC()
+
+    def __init__(
+        self, name: str, *, min: Optional[int] = None, max: Optional[int] = None
+    ):
         from torch.utils._sympy.numbers import int_oo
 
-        if min_ == 2:
-            min_ = None
-        if max_ == int_oo:
-            max_ = None
-        if min_ is None and max_ is None:
-            return f"Dim('{name}')"
-        if min_ is None:
-            return f"Dim('{name}', max={max_})"
-        if max_ is None:
-            return f"Dim('{name}', min={min_})"
-        return f"Dim('{name}', min={min_}, max={max_})"
+        _min = 0 if min is None else min
+        _max = int_oo if max is None else max
+        assert _max > _min, f"Cannot create Dim with inconsistent min={min}, max={max}"
+        assert name.isidentifier(), f"Dim name must be a valid identifier, got {name}"
+        self.__name__ = name
+        self.min = _min
+        self.max = _max
 
-    def __add__(cls, other):
+    def __add__(self, other) -> "Dim":
         # e.g., dim + 1
         if type(other) is not int:
             raise NotImplementedError(
-                f"Attempted to add {other} to {cls.__name__}, where an integer was expected. "
+                f"Attempted to add {other} to {self.__name__}, where an integer was expected. "
                 "(Only increasing linear operations with integer coefficients are supported.)"
             )
-        return cls._derive(lambda x: x + other)
+        return self._derive(lambda x: x + other)
 
-    def __radd__(cls, other):
-        return cls + other
+    def __radd__(self, other) -> "Dim":
+        return self + other
 
-    def __sub__(cls, other):
+    def __sub__(self, other) -> "Dim":
         # e.g., dim - 1
         if type(other) is not int:
             raise NotImplementedError(
-                f"Attempted to subtract {other} from {cls.__name__}, where an integer was expected. "
+                f"Attempted to subtract {other} from {self.__name__}, where an integer was expected. "
                 "(Only increasing linear operations with integer coefficients are supported.)"
             )
-        return cls._derive(lambda x: x - other)
+        return self._derive(lambda x: x - other)
 
-    def __rsub__(cls, other):
+    def __rsub__(self, other) -> "Dim":
         raise NotImplementedError(
-            f"Attempted to negate {cls.__name__}. "
+            f"Attempted to negate {self.__name__}. "
            "(Only increasing linear operations with integer coefficients are supported.)"
        )
 
-    def __mul__(cls, other):
+    def __mul__(self, other) -> "Dim":
         # e.g., dim * 2
         if type(other) is not int or other <= 0:
             raise NotImplementedError(
-                f"Attempted to multiply {other} with {cls.__name__}, where a positive integer was expected. "
+                f"Attempted to multiply {other} with {self.__name__}, where a positive integer was expected. "
                "(Only increasing linear operations with integer coefficients are supported.)"
            )
-        return cls._derive(lambda x: x * other)
+        return self._derive(lambda x: x * other)
 
-    def __rmul__(cls, other):
-        return cls * other
+    def __rmul__(self, other) -> "Dim":
+        return self * other
 
-    def _derived_name(cls, fn):
+    def _derived_name(self, fn) -> str:
         from sympy import sympify
 
-        return str(fn(sympify(cls.__name__)))
+        return str(fn(sympify(self.__name__)))
 
-    def _derive(cls, fn):
-        return _DerivedDim(cls._derived_name(fn), (int,), {"root": cls, "fn": fn})
+    def _derive(self, fn) -> "Dim":
+        return _DerivedDim(self._derived_name(fn), self, fn)
 
+    @staticmethod
+    def readable(name: str, min_: int, max_: int) -> str:
+        from torch.utils._sympy.numbers import int_oo
 
-class _StaticDim(_Dim):
+        if min_ == 2:
+            min_ = None  # type: ignore[assignment]
+        if max_ == int_oo:
+            max_ = None  # type: ignore[assignment]
+        if min_ is None and max_ is None:
+            return f"Dim('{name}')"
+        if min_ is None:
+            return f"Dim('{name}', max={max_})"
+        if max_ is None:
+            return f"Dim('{name}', min={min_})"
+        return f"Dim('{name}', min={min_}, max={max_})"
+
+    def __repr__(self):
+        return Dim.readable(self.__name__, self.min, self.max)
+
+
+class _StaticDim(Dim):
     """
-    Meta class for static :func:`Dim` types.
+    Class for static :func:`Dim` types.
 
     This class is only for setting and checking static dim constraints,
     and the user should never interact with it.
     """
 
+    def __init__(self, value: int):
+        self.__name__ = str(value)
+        self.value = value
+
     @property
     def min(self):
         return self.value  # type: ignore[attr-defined]
@@ -156,9 +191,9 @@ def max(self):
         return self.value  # type: ignore[attr-defined]
 
 
-class _DerivedDim(_Dim):
+class _DerivedDim(Dim):
     """
-    Metaclass for derived :func:`Dim` types.
+    Class for derived :func:`Dim` types.
 
     Currently we only support increasing linear expressions with integer coefficients.
     In other words, a derived Dim can always be written in the form Ax + B, where
@@ -172,6 +207,11 @@ class _DerivedDim(_Dim):
     The range of a derived Dim is computed by mapping `fn` over the range of its `root`.
     """
 
+    def __init__(self, name: str, root: Dim, fn: Callable):
+        self.__name__ = name
+        self.root = root
+        self.fn = fn
+
     @property
     def min(self):
         # assume that self.fn is an increasing function
@@ -218,50 +258,17 @@ def _derive(self, fn):
         # As a consequence, roots are always regular Dims (i.e., not derived Dims).
         return _DerivedDim(
             self._derived_name(fn),
-            (int,),
-            {"root": self.root, "fn": lambda x: fn(self.fn(x))},  # type: ignore[attr-defined]
+            self.root,
+            lambda x: fn(self.fn(x)),
         )
 
-
-class Dim(type):
-    """
-    :func:`Dim` constructs a type analogous to a named symbolic integer with a range.
-    It can be used to describe multiple possible values of a dynamic tensor dimension.
-    Note that different dynamic dimensions of the same tensor, or of different tensors,
-    can be described by the same type.
-
-    Args:
-        name (str): Human-readable name for debugging.
-        min (Optional[int]): Minimum possible value of given symbol (inclusive)
-        max (Optional[int]): Maximum possible value of given symbol (inclusive)
-
-    Returns:
-        A type that can be used in dynamic shape specifications for tensors.
-    """
-
-    AUTO = _DimHint.AUTO()
-    DYNAMIC = _DimHint.DYNAMIC()
-    STATIC = _DimHint.STATIC()
-
-    def __new__(
-        metacls, name: str, *, min: Optional[int] = None, max: Optional[int] = None
-    ):
-        from torch.utils._sympy.numbers import int_oo
-
-        _min = 0 if min is None else min
-        _max = int_oo if max is None else max
-        assert _max > _min, f"Cannot create Dim with inconsistent min={min}, max={max}"
-        assert name.isidentifier(), f"Dim name must be a valid identifier, got {name}"
-        dim = _Dim(name, (int,), {"min": _min, "max": _max})
-        dim.__module__ = getattr(
-            inspect.getmodule(inspect.stack()[1][0]), "__name__", "__main__"
-        )
-        return dim
+    def __repr__(self):
+        return self.__name__
 
 
 def dims(
     *names: str, min: Optional[int] = None, max: Optional[int] = None
-) -> tuple[_Dim, ...]:
+) -> tuple[Dim, ...]:
     """
     Util to create multiple :func:`Dim` types.
 
@@ -722,8 +729,8 @@ def check_same_bounds(dim):
         if dim.__name__ in bounds:
             min_, max_ = bounds[dim.__name__]
             if dim.min != min_ or dim.max != max_:
-                this_ = _Dim.readable(dim.__name__, min_, max_)
-                that_ = _Dim.readable(dim.__name__, dim.min, dim.max)
+                this_ = Dim.readable(dim.__name__, min_, max_)
+                that_ = Dim.readable(dim.__name__, dim.min, dim.max)
                 raise UserError(
                     UserErrorType.INVALID_INPUT,
                     f"Found different definitions {this_} and {that_} "
@@ -735,7 +742,7 @@ def check_same_bounds(dim):
     def check_symbols(path, tensor, shape):
         if isinstance(shape, dict):
             for i, dim in shape.items():
-                if isinstance(dim, _Dim):
+                if isinstance(dim, Dim):
                     check_same_bounds(dim)
                 elif dim is None:
                     _warn_on_None_dynamic_shape_dimension()
@@ -750,7 +757,7 @@ def check_symbols(path, tensor, shape):
                 )
         elif isinstance(shape, (tuple, list)):
             for i, dim in enumerate(shape):
-                if isinstance(dim, _Dim):
+                if isinstance(dim, Dim):
                     check_same_bounds(dim)
                 elif dim is None:
                     _warn_on_None_dynamic_shape_dimension()
@@ -911,7 +918,7 @@ def root_value():
                 ),
             )
         else:
-            assert isinstance(dim, _Dim)
+            assert isinstance(dim, Dim)
             constraint = _Constraint(  # type: ignore[assignment]
                 id(tensor),
                 i,
@@ -924,7 +931,7 @@
 
     def update_symbols(path, tensor, shape):
         def _create_static_dim(tensor, i, value):
-            return _StaticDim(str(value), (int,), {"value": value})
+            return _StaticDim(value)
 
         # clean out decorators from user side, or previous export call
         # we also delete these attributes in non_strict_utils.py/make_constraints()
@@ -936,7 +943,7 @@ def _create_static_dim(tensor, i, value):
 
         if isinstance(shape, dict):
             for i, dim in shape.items():
-                if isinstance(dim, (int, _Dim)):
+                if isinstance(dim, (int, Dim)):
                     if isinstance(dim, int):
                         dim = _create_static_dim(tensor, i, dim)
                     constraint = to_constraint(dim, tensor, i)
@@ -953,7 +960,7 @@ def _create_static_dim(tensor, i, value):
                     torch._dynamo.mark_static(tensor, i)
         elif isinstance(shape, (tuple, list)):
             for i, dim in enumerate(shape):
-                if isinstance(dim, (int, _Dim)):
+                if isinstance(dim, (int, Dim)):
                     if isinstance(dim, int):
                         dim = _create_static_dim(tensor, i, dim)
                     constraint = to_constraint(dim, tensor, i)
@@ -1002,14 +1009,14 @@ def _get_dim_name_mapping(
     name_to_dim = {}
     for dim in tree_flatten(
         dynamic_shapes,
-        is_leaf=lambda x: isinstance(x, _Dim),
+        is_leaf=lambda x: isinstance(x, Dim),
     )[0]:
         if dim is None:
             # NOTE: this must denote a non-Tensor or automatic at this point.
             continue
         if isinstance(dim, int):
             continue
-        elif isinstance(dim, _Dim):
+        elif isinstance(dim, Dim):
             name_to_dim[dim.__name__] = dim
             if isinstance(dim, _DerivedDim):
                 name_to_dim[dim.root.__name__] = dim.root  # type: ignore[attr-defined]
@@ -1092,7 +1099,7 @@ def refine_dynamic_shapes_from_suggested_fixes(
     # track derived dim roots
     roots: set[str] = set()
     for k, c in shape_fixes.items():
-        assert isinstance(c, (int, _Dim, _DerivedDim, sympy.Expr))
+        assert isinstance(c, (int, Dim, _DerivedDim, sympy.Expr))
         if isinstance(c, sympy.Expr):  # check dim/derived dim expression
             assert _is_supported_equivalence(c)
             shape_fixes[k] = c
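
One behavioral note on the file above: _DerivedDim is now constructed as _DerivedDim(name, root, fn) instead of the metaclass-style _DerivedDim(name, (int,), {...}). A hedged sketch of derived-dim behavior after the change (the arithmetic operators are unchanged; the range follows the "map fn over the root's range" rule from the docstring):

from torch.export import Dim

dx = Dim("dx", min=4, max=16)

# Dim arithmetic builds a _DerivedDim via _derive(); only increasing linear
# expressions with integer coefficients (Ax + B) are supported.
dy = 2 * dx + 1

# The derived range is fn mapped over the root's range: [2*4+1, 2*16+1].
assert (dy.min, dy.max) == (9, 33)

# _DerivedDim.__repr__ returns the derived name produced via sympy.
assert repr(dy) == "2*dx + 1"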
