diff --git a/torchrec/ir/utils.py b/torchrec/ir/utils.py
index 990f16ff1..fca0dd621 100644
--- a/torchrec/ir/utils.py
+++ b/torchrec/ir/utils.py
@@ -18,7 +18,6 @@
 from torch import nn
 from torch.export import Dim, ShapesCollection
-from torch.export.dynamic_shapes import _Dim as DIM
 from torch.export.unflatten import InterpreterModule
 from torch.fx import Node
 
 from torchrec.ir.types import SerializerInterface
@@ -169,13 +168,12 @@ def decapsulate_ir_modules(
     return module
 
 
-def _get_dim(name: str, min: Optional[int] = None, max: Optional[int] = None) -> DIM:
+def _get_dim(name: str, min: Optional[int] = None, max: Optional[int] = None) -> Dim:
     """
     Returns a Dim object with the given name and min/max. If the name is not unique,
     it will append a suffix to the name.
     """
     dim = f"{name}_{DYNAMIC_DIMS[name]}"
     DYNAMIC_DIMS[name] += 1
-    # pyre-ignore[7]: Expected `DIM` but got `Dim`.
     return Dim(dim, min=min, max=max)
 
@@ -183,8 +181,8 @@ def mark_dynamic_kjt(
     kjt: KeyedJaggedTensor,
     shapes_collection: Optional[ShapesCollection] = None,
     variable_length: bool = False,
-    vlen: Optional[DIM] = None,
-    llen: Optional[DIM] = None,
+    vlen: Optional[Dim] = None,
+    llen: Optional[Dim] = None,
 ) -> ShapesCollection:
     """
     Makes the given KJT dynamic. If it's not variable length, it will only have
@@ -203,9 +201,9 @@ def mark_dynamic_kjt(
         kjt (KeyedJaggedTensor): The KJT to make dynamic.
        shapes_collection (Optional[ShapesCollection]): The collection to update.
         variable_length (bool): Whether the KJT is variable length.
-        vlen (Optional[DIM]): The dynamic length for the values. If it's None, it will use the default name "vlen".
-        llen (Optional[DIM]): The dynamic length for the lengths, it's only used when variable_length is true. If it's None, it will use the default name "llen".
-        batch_size (Optional[DIM]): The dynamic length for the batch_size, it's only used when variable_length and mark_batch_size are both true.
+        vlen (Optional[Dim]): The dynamic length for the values. If it's None, it will use the default name "vlen".
+        llen (Optional[Dim]): The dynamic length for the lengths, it's only used when variable_length is true. If it's None, it will use the default name "llen".
+        batch_size (Optional[Dim]): The dynamic length for the batch_size, it's only used when variable_length and mark_batch_size are both true.
     """
 
     def _has_dim(t: Optional[torch.Tensor]) -> bool:
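
For context, a minimal usage sketch of the API this patch touches (a sketch only, assuming torchrec with this patch applied; the KJT values and the `model` in the final comment are hypothetical):

import torch
from torchrec.ir.utils import mark_dynamic_kjt
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Build a small KJT; the keys/values/lengths here are made-up example data.
kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1", "f2"],
    values=torch.tensor([0, 1, 2, 3, 4]),
    lengths=torch.tensor([2, 1, 1, 1]),
)

# With variable_length=False (the default), only the values' length
# ("vlen") is marked dynamic. After this patch the returned
# ShapesCollection is built from the public torch.export.Dim rather
# than the removed private _Dim alias.
shapes = mark_dynamic_kjt(kjt)

# The collection can then be passed to torch.export, e.g.:
# ep = torch.export.export(model, (kjt,), dynamic_shapes=shapes)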