14 changes: 6 additions & 8 deletions torchrec/ir/utils.py
@@ -18,7 +18,6 @@
 
 from torch import nn
 from torch.export import Dim, ShapesCollection
-from torch.export.dynamic_shapes import _Dim as DIM
 from torch.export.unflatten import InterpreterModule
 from torch.fx import Node
 from torchrec.ir.types import SerializerInterface
@@ -169,22 +168,21 @@ def decapsulate_ir_modules(
     return module
 
 
-def _get_dim(name: str, min: Optional[int] = None, max: Optional[int] = None) -> DIM:
+def _get_dim(name: str, min: Optional[int] = None, max: Optional[int] = None) -> Dim:
     """
     Returns a Dim object with the given name and min/max. If the name is not unique, it will append a suffix to the name.
     """
     dim = f"{name}_{DYNAMIC_DIMS[name]}"
     DYNAMIC_DIMS[name] += 1
-    # pyre-ignore[7]: Expected `DIM` but got `Dim`.
     return Dim(dim, min=min, max=max)
 
 
 def mark_dynamic_kjt(
     kjt: KeyedJaggedTensor,
     shapes_collection: Optional[ShapesCollection] = None,
     variable_length: bool = False,
-    vlen: Optional[DIM] = None,
-    llen: Optional[DIM] = None,
+    vlen: Optional[Dim] = None,
+    llen: Optional[Dim] = None,
 ) -> ShapesCollection:
     """
     Makes the given KJT dynamic. If it's not variable length, it will only have
@@ -203,9 +201,9 @@ def mark_dynamic_kjt(
         kjt (KeyedJaggedTensor): The KJT to make dynamic.
         shapes_collection (Optional[ShapesCollection]): The collection to update.
         variable_length (bool): Whether the KJT is variable length.
-        vlen (Optional[DIM]): The dynamic length for the values. If it's None, it will use the default name "vlen".
-        llen (Optional[DIM]): The dynamic length for the lengths, it's only used when variable_length is true. If it's None, it will use the default name "llen".
-        batch_size (Optional[DIM]): The dynamic length for the batch_size, it's only used when variable_length and mark_batch_size are both true.
+        vlen (Optional[Dim]): The dynamic length for the values. If it's None, it will use the default name "vlen".
+        llen (Optional[Dim]): The dynamic length for the lengths, it's only used when variable_length is true. If it's None, it will use the default name "llen".
+        batch_size (Optional[Dim]): The dynamic length for the batch_size, it's only used when variable_length and mark_batch_size are both true.
     """
 
     def _has_dim(t: Optional[torch.Tensor]) -> bool:
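This diff swaps the private `torch.export.dynamic_shapes._Dim` alias for the public `torch.export.Dim` type, which also lets the `pyre-ignore[7]` suppression in `_get_dim` go away, since the function now returns exactly what `Dim(...)` produces. Below is a minimal caller-side sketch of how the updated `mark_dynamic_kjt` signature might be used after this change; it is not part of the PR, and the KJT contents and dim bounds are illustrative assumptions.

```python
import torch
from torch.export import Dim
from torchrec.ir.utils import mark_dynamic_kjt
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# A small example KJT: 2 keys, batch size 2, 5 values total
# (values/lengths here are made up for illustration).
kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1", "f2"],
    values=torch.tensor([1, 2, 3, 4, 5]),
    lengths=torch.tensor([2, 1, 1, 1]),
)

# After this PR, callers pass the public torch.export.Dim type directly;
# min/max bound the dynamic range of the values dimension.
vlen = Dim("vlen", min=1, max=4096)
shapes = mark_dynamic_kjt(kjt, vlen=vlen)

# `shapes` is a ShapesCollection, intended to feed torch.export.export(...)
# through its dynamic_shapes argument.
```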