Skip to content

Commit f0a0622

Browse files
pianpwkamathewc
authored andcommitted
[export] refactor DimHints for type errors (pytorch#149424)
Differential Revision: D71414367 Pull Request resolved: pytorch#149424 Approved by: https://github.com/justinchuby, https://github.com/avikchaudhuri
1 parent bd5dde4 commit f0a0622

File tree

6 files changed

+59
-39
lines changed

6 files changed

+59
-39
lines changed

test/inductor/test_cutlass_backend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -736,7 +736,7 @@ def forward(self, x, w):
736736
model = MyModel()
737737
M, N, K = 128, 64, 64
738738
dynamic_shapes = {
739-
"x": {0: Dim.DYNAMIC}, # type: ignore[attr-defined]
739+
"x": {0: Dim.DYNAMIC},
740740
"w": None,
741741
}
742742

torch/_export/converter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1514,7 +1514,7 @@ def retrace_as_exported_program(
15141514
):
15151515
dynamic_shapes = _tree_map_with_path(
15161516
lambda path, x: (
1517-
[Dim.AUTO] * x.dim() if isinstance(x, torch.Tensor) else None # type: ignore[attr-defined]
1517+
[Dim.AUTO] * x.dim() if isinstance(x, torch.Tensor) else None
15181518
),
15191519
self.sample_args,
15201520
)

torch/_export/serde/dynamic_shapes.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ def _standardize_shapes(path, tensor, shape): # type: ignore[no-untyped-def]
137137
if not isinstance(tensor, torch.Tensor):
138138
return None
139139
if shape is None:
140-
return [Dim.STATIC] * len(tensor.shape) # type: ignore[attr-defined]
140+
return [Dim.STATIC] * len(tensor.shape)
141141

142142
out = []
143143
if isinstance(shape, dict):
@@ -158,7 +158,7 @@ def _track_dim_from_dims(
158158
if val is None or isinstance(val, int): # non-tensor input or static
159159
return val
160160
if isinstance(val, _DimHint): # store enum as string
161-
return val.__class__.__name__ + "." + val.name
161+
return val.__class__.__name__ + "." + val.type.name
162162

163163
assert isinstance(val, _Dim)
164164

@@ -290,9 +290,9 @@ def _load_dynamic_shapes(
290290
modulus, remainder = sympy.polys.polytools.div(expr, symbol)
291291
ddim = dim_cache[name]
292292
if modulus != 1:
293-
ddim = int(modulus) * ddim
293+
ddim = int(modulus) * ddim # type: ignore[assignment, operator]
294294
if remainder != 0:
295-
ddim = ddim + int(remainder)
295+
ddim = ddim + int(remainder) # type: ignore[assignment, operator]
296296
dim_cache[_expr] = ddim # cache derived dims
297297

298298
def deserialize_shape(
@@ -301,9 +301,11 @@ def deserialize_shape(
301301
if val is None or isinstance(val, int):
302302
return val
303303
elif val == "_DimHint.AUTO":
304-
return _DimHint.AUTO
304+
return _DimHint.AUTO()
305+
elif val == "_DimHint.DYNAMIC":
306+
return _DimHint.DYNAMIC()
305307
elif val == "_DimHint.STATIC":
306-
return _DimHint.STATIC
308+
return _DimHint.STATIC()
307309
if not isinstance(val, str):
308310
raise UserError(
309311
UserErrorType.INVALID_INPUT,
@@ -316,6 +318,6 @@ def deserialize_shape(
316318
"Expected dims in `spec['dynamic_shapes']` to be tracked in `spec['dims']`, "
317319
f"got {val} which is not in {dims.keys()}",
318320
)
319-
return dim_cache[val]
321+
return dim_cache[val] # type: ignore[return-value]
320322

321323
return tree_map(deserialize_shape, dynamic_shapes)

torch/export/dynamic_shapes.py

Lines changed: 43 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@
4040
log = logging.getLogger(__name__)
4141

4242

43-
class _DimHint(Enum):
43+
class _DimHintType(Enum):
4444
"""
4545
Enum for dynamic shape hints.
4646
- AUTO means automatic inference of shape (static or dynamic).
@@ -53,6 +53,23 @@ class _DimHint(Enum):
5353
DYNAMIC = auto()
5454

5555

56+
@dataclasses.dataclass
57+
class _DimHint:
58+
type: _DimHintType
59+
60+
@staticmethod
61+
def AUTO():
62+
return _DimHint(_DimHintType.AUTO)
63+
64+
@staticmethod
65+
def DYNAMIC():
66+
return _DimHint(_DimHintType.DYNAMIC)
67+
68+
@staticmethod
69+
def STATIC():
70+
return _DimHint(_DimHintType.STATIC)
71+
72+
5673
class _Dim(type):
5774
"""
5875
Metaclass for :func:`Dim` types.
@@ -206,7 +223,7 @@ def _derive(self, fn):
206223
)
207224

208225

209-
def Dim(name: str, *, min: Optional[int] = None, max: Optional[int] = None):
226+
class Dim(type):
210227
"""
211228
:func:`Dim` constructs a type analogous to a named symbolic integer with a range.
212229
It can be used to describe multiple possible values of a dynamic tensor dimension.
@@ -222,22 +239,24 @@ def Dim(name: str, *, min: Optional[int] = None, max: Optional[int] = None):
222239
A type that can be used in dynamic shape specifications for tensors.
223240
"""
224241

225-
from torch.utils._sympy.numbers import int_oo
226-
227-
_min = 0 if min is None else min
228-
_max = int_oo if max is None else max
229-
assert _max > _min, f"Cannot create Dim with inconsistent min={min}, max={max}"
230-
assert name.isidentifier(), f"Dim name must be a valid identifier, got {name}"
231-
dim = _Dim(name, (int,), {"min": _min, "max": _max})
232-
dim.__module__ = getattr(
233-
inspect.getmodule(inspect.stack()[1][0]), "__name__", "__main__"
234-
)
235-
return dim
242+
AUTO = _DimHint.AUTO()
243+
DYNAMIC = _DimHint.DYNAMIC()
244+
STATIC = _DimHint.STATIC()
236245

246+
def __new__(
247+
metacls, name: str, *, min: Optional[int] = None, max: Optional[int] = None
248+
):
249+
from torch.utils._sympy.numbers import int_oo
237250

238-
Dim.AUTO = _DimHint.AUTO # type: ignore[attr-defined]
239-
Dim.STATIC = _DimHint.STATIC # type: ignore[attr-defined]
240-
Dim.DYNAMIC = _DimHint.DYNAMIC # type: ignore[attr-defined]
251+
_min = 0 if min is None else min
252+
_max = int_oo if max is None else max
253+
assert _max > _min, f"Cannot create Dim with inconsistent min={min}, max={max}"
254+
assert name.isidentifier(), f"Dim name must be a valid identifier, got {name}"
255+
dim = _Dim(name, (int,), {"min": _min, "max": _max})
256+
dim.__module__ = getattr(
257+
inspect.getmodule(inspect.stack()[1][0]), "__name__", "__main__"
258+
)
259+
return dim
241260

242261

243262
def dims(
@@ -249,7 +268,7 @@ def dims(
249268
Returns:
250269
A tuple of :func:`Dim` types.
251270
"""
252-
return tuple(Dim(name, min=min, max=max) for name in names)
271+
return tuple(Dim(name, min=min, max=max) for name in names) # type: ignore[misc]
253272

254273

255274
@dataclasses.dataclass
@@ -923,11 +942,11 @@ def _create_static_dim(tensor, i, value):
923942
constraint = to_constraint(dim, tensor, i)
924943
symbols[dim.__name__].append(constraint)
925944
elif isinstance(dim, _DimHint):
926-
if dim == _DimHint.AUTO:
945+
if dim.type == _DimHintType.AUTO:
927946
torch._dynamo.maybe_mark_dynamic(tensor, i)
928-
elif dim == _DimHint.STATIC:
947+
elif dim.type == _DimHintType.STATIC:
929948
torch._dynamo.mark_static(tensor, i)
930-
elif dim == _DimHint.DYNAMIC:
949+
elif dim.type == _DimHintType.DYNAMIC:
931950
torch._dynamo.mark_dynamic(tensor, i)
932951
constraints.append(_RelaxedConstraint(id(tensor), i))
933952
elif dim is None:
@@ -940,11 +959,11 @@ def _create_static_dim(tensor, i, value):
940959
constraint = to_constraint(dim, tensor, i)
941960
symbols[dim.__name__].append(constraint)
942961
elif isinstance(dim, _DimHint):
943-
if dim == _DimHint.AUTO:
962+
if dim.type == _DimHintType.AUTO:
944963
torch._dynamo.maybe_mark_dynamic(tensor, i)
945-
elif dim == _DimHint.STATIC:
964+
elif dim.type == _DimHintType.STATIC:
946965
torch._dynamo.mark_static(tensor, i)
947-
elif dim == _DimHint.DYNAMIC:
966+
elif dim.type == _DimHintType.DYNAMIC:
948967
torch._dynamo.mark_dynamic(tensor, i)
949968
constraints.append(_RelaxedConstraint(id(tensor), i))
950969
elif dim is None:
@@ -1063,7 +1082,7 @@ def refine_dynamic_shapes_from_suggested_fixes(
10631082
expr = sympy.sympify(expr)
10641083
if isinstance(expr, sympy.Number):
10651084
# static, integer
1066-
shape_fixes[name] = int(expr)
1085+
shape_fixes[name] = int(expr) # type: ignore[assignment]
10671086
else:
10681087
# relation or derived dim
10691088
shape_fixes[name] = expr

torch/onnx/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -474,7 +474,7 @@ def _to_dynamic_shape(x):
474474
rank = len(x.shape)
475475
dynamic_shape = {}
476476
for i in range(rank):
477-
dynamic_shape[i] = torch.export.Dim.AUTO # type: ignore[attr-defined]
477+
dynamic_shape[i] = torch.export.Dim.AUTO
478478
return dynamic_shape
479479
else:
480480
return None

torch/onnx/_internal/exporter/_dynamic_shapes.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,15 +63,14 @@ def from_dynamic_axes_to_dynamic_shapes(
6363
"The axis in dynamic_axes must be in the form of: dict[int, str] or list[int]."
6464
)
6565
dynamic_shapes[input_name] = {
66-
k: torch.export.Dim.AUTO # type: ignore[attr-defined]
67-
for k, _ in axes.items()
66+
k: torch.export.Dim.AUTO for k, _ in axes.items()
6867
}
6968
elif isinstance(axes, list):
7069
if any(not isinstance(k, int) for k in axes):
7170
raise ValueError(
7271
"The axis in dynamic_axes must be in the form of: dict[int, str] or list[int]."
7372
)
74-
dynamic_shapes[input_name] = {k: torch.export.Dim.AUTO for k in axes} # type: ignore[attr-defined]
73+
dynamic_shapes[input_name] = {k: torch.export.Dim.AUTO for k in axes}
7574
elif axes is None:
7675
dynamic_shapes[input_name] = None
7776
else:
@@ -203,15 +202,15 @@ def convert_str_to_export_dim(
203202
converted_axes_dict: dict[int, _Dim | _DimHint | None] = {}
204203
for axis, dim in axes.items():
205204
if isinstance(dim, str):
206-
converted_axes_dict[axis] = torch.export.Dim.AUTO # type: ignore[attr-defined]
205+
converted_axes_dict[axis] = torch.export.Dim.AUTO
207206
else:
208207
converted_axes_dict[axis] = dim
209208
dynamic_shapes_with_export_dim.append(converted_axes_dict)
210209
elif isinstance(axes, (list, tuple)):
211210
converted_axes_list: list[_Dim | _DimHint | None] = []
212211
for dim in axes:
213212
if isinstance(dim, str):
214-
converted_axes_list.append(torch.export.Dim.AUTO) # type: ignore[attr-defined]
213+
converted_axes_list.append(torch.export.Dim.AUTO)
215214
else:
216215
converted_axes_list.append(dim)
217216
dynamic_shapes_with_export_dim.append(converted_axes_list)

0 commit comments

Comments
 (0)