
Commit fc4da98

pianpwk authored and facebook-github-bot committed
refactor _Dim into Dim (#2847)
Summary:
X-link: pytorch/pytorch#149891
X-link: pytorch/executorch#9559

Forward fix for T218515233.

Differential Revision: D71769231
1 parent 7652c5d commit fc4da98

File tree

1 file changed: +6 -8 lines changed


torchrec/ir/utils.py

Lines changed: 6 additions & 8 deletions
@@ -18,7 +18,6 @@

 from torch import nn
 from torch.export import Dim, ShapesCollection
-from torch.export.dynamic_shapes import _Dim as DIM
 from torch.export.unflatten import InterpreterModule
 from torch.fx import Node
 from torchrec.ir.types import SerializerInterface
@@ -169,22 +168,21 @@ def decapsulate_ir_modules(
     return module


-def _get_dim(name: str, min: Optional[int] = None, max: Optional[int] = None) -> DIM:
+def _get_dim(name: str, min: Optional[int] = None, max: Optional[int] = None) -> Dim:
     """
     Returns a Dim object with the given name and min/max. If the name is not unique, it will append a suffix to the name.
     """
     dim = f"{name}_{DYNAMIC_DIMS[name]}"
     DYNAMIC_DIMS[name] += 1
-    # pyre-ignore[7]: Expected `DIM` but got `Dim`.
     return Dim(dim, min=min, max=max)


 def mark_dynamic_kjt(
     kjt: KeyedJaggedTensor,
     shapes_collection: Optional[ShapesCollection] = None,
     variable_length: bool = False,
-    vlen: Optional[DIM] = None,
-    llen: Optional[DIM] = None,
+    vlen: Optional[Dim] = None,
+    llen: Optional[Dim] = None,
 ) -> ShapesCollection:
     """
     Makes the given KJT dynamic. If it's not variable length, it will only have
@@ -203,9 +201,9 @@ def mark_dynamic_kjt(
         kjt (KeyedJaggedTensor): The KJT to make dynamic.
         shapes_collection (Optional[ShapesCollection]): The collection to update.
         variable_length (bool): Whether the KJT is variable length.
-        vlen (Optional[DIM]): The dynamic length for the values. If it's None, it will use the default name "vlen".
-        llen (Optional[DIM]): The dynamic length for the lengths, it's only used when variable_length is true. If it's None, it will use the default name "llen".
-        batch_size (Optional[DIM]): The dynamic length for the batch_size, it's only used when variable_length and mark_batch_size are both true.
+        vlen (Optional[Dim]): The dynamic length for the values. If it's None, it will use the default name "vlen".
+        llen (Optional[Dim]): The dynamic length for the lengths, it's only used when variable_length is true. If it's None, it will use the default name "llen".
+        batch_size (Optional[Dim]): The dynamic length for the batch_size, it's only used when variable_length and mark_batch_size are both true.
     """

     def _has_dim(t: Optional[torch.Tensor]) -> bool:
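For context, a minimal sketch (not part of this commit) of how the public torch.export.Dim that this change adopts declares a dynamic dimension; the Toy module, the "vlen_0" name (mirroring the suffixed names _get_dim generates), and the bounds are illustrative assumptions:

import torch
from torch.export import Dim, export

class Toy(torch.nn.Module):
    def forward(self, values: torch.Tensor) -> torch.Tensor:
        return values * 2

# Hypothetical usage: a named dynamic dimension with optional bounds,
# analogous to what _get_dim returns after this refactor.
vlen = Dim("vlen_0", min=2, max=4096)

# Mark dim 0 of `values` as dynamic for export.
ep = export(Toy(), (torch.randn(8),), dynamic_shapes={"values": {0: vlen}})

Per the docstring above, mark_dynamic_kjt assembles such Dim objects into a ShapesCollection for a KJT's values and lengths.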
