1 change: 1 addition & 0 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2026,6 +2026,7 @@ def aten_ops_sub(
)


@dynamo_tensorrt_converter(operator.truediv, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.div.Tensor_mode, supports_dynamic_shapes=True)
@dynamo_tensorrt_converter(torch.ops.aten.div.Scalar, supports_dynamic_shapes=True)
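Why a converter for operator.truediv is needed: when symbolic sizes feed scalar arithmetic, torch.export records that math as plain Python operator call_function nodes, so an operator.truediv node can end up inside a TensorRT-bound subgraph. A minimal sketch of how such a node arises (the module, dimension name, and exact node layout are illustrative and may vary across PyTorch versions):

import operator

import torch


class ScalarDiv(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # With a dynamic batch dimension, x.numel() is a SymInt, so this
        # division is symbolic scalar math rather than a tensor op.
        scale = x.shape[-1] / x.numel()
        return x * scale


batch = torch.export.Dim("batch", min=1, max=8)
ep = torch.export.export(
    ScalarDiv(), (torch.randn(2, 8),), dynamic_shapes=({0: batch},)
)
# Depending on the PyTorch version, the symbolic division shows up as an
# operator.truediv call_function node in the exported graph.
truediv_nodes = [
    n
    for n in ep.graph.nodes
    if n.op == "call_function" and n.target is operator.truediv
]
print(truediv_nodes)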
9 changes: 9 additions & 0 deletions py/torch_tensorrt/dynamo/partitioning/common.py
@@ -102,6 +102,15 @@ def construct_submodule_inputs(module: torch.fx.GraphModule) -> Sequence[Input]:
is_shape_tensor=True,
)
)
elif isinstance(input_meta, torch.SymFloat):
Collaborator: So it'll just be a 0D tensor?

Collaborator (Author): I'm registering its shape as (1,), so 1D.

torchtrt_inputs.append(
get_input(
[1],
torch.float32,
name=input.name,
is_shape_tensor=False, # Only SymInt inputs are treated as shape tensors
)
)
else:
raise ValueError(
f"The meta val for input node {input.target} is of type : {type(input_meta)}. Supported types: torch.Tensor|FakeTensor|torch.SymInt"
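As discussed in the review thread above, a SymFloat placeholder is registered as a one-element, 1D float32 input rather than a 0-D tensor or a shape tensor. Roughly, the spec produced by this branch is equivalent to the following (the input name is illustrative, and get_input may set additional defaults):

import torch
import torch_tensorrt as torchtrt

# A 1D, single-element fp32 input; unlike SymInt inputs, it is not a shape tensor.
symfloat_input = torchtrt.Input(shape=[1], dtype=torch.float32, name="sym_float_arg")
print(symfloat_input)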
5 changes: 5 additions & 0 deletions py/torch_tensorrt/dynamo/utils.py
@@ -455,6 +455,9 @@ def unwrap_tensor_shape(
tensor_shape.append(min_max_opt[mode])
else:
tensor_shape.append((min_max_opt["min"], min_max_opt["max"]))
elif isinstance(tensor, torch.SymFloat):
# SymFloat values can sometimes appear as graph inputs. We register their shape as [1] to avoid errors.
tensor_shape.append(1)
elif isinstance(tensor, (torch.Tensor, FakeTensor)):
for dimension in tensor.shape:
tensor_shape.extend(unwrap_tensor_shape(dimension, mode=mode))
@@ -472,6 +475,8 @@ def unwrap_tensor_dtype(tensor: Union[torch.Tensor, FakeTensor, torch.SymInt]) -
return torch.tensor(tensor).dtype
elif isinstance(tensor, torch.SymInt):
return torch.int64
elif isinstance(tensor, torch.SymFloat):
return torch.float32
elif tensor is None:
# Case where we explicitly pass one of the inputs to be None (eg: FLUX.1-dev)
return None
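Taken together, the helpers above treat symbolic scalars consistently: a SymInt is reported with dtype int64 (and handled as a shape tensor elsewhere), while a SymFloat is reported as a single-element float32 value. A standalone mirror of the dtype mapping, for reference only (this is not the torch_tensorrt helper itself):

from typing import Union

import torch


def sym_scalar_dtype(value: Union[torch.SymInt, torch.SymFloat]) -> torch.dtype:
    # Mirrors unwrap_tensor_dtype above: symbolic ints surface as int64,
    # symbolic floats (added in this change) as float32.
    if isinstance(value, torch.SymInt):
        return torch.int64
    if isinstance(value, torch.SymFloat):
        return torch.float32
    raise ValueError(f"Expected SymInt or SymFloat, got {type(value)}")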
100 changes: 100 additions & 0 deletions tests/py/dynamo/models/test_models.py
@@ -2,9 +2,11 @@
import importlib
import platform
import unittest
from typing import Optional

import pytest
import torch
import torch.nn as nn
import torch_tensorrt as torchtrt
from torch_tensorrt.dynamo.utils import (
COSINE_THRESHOLD,
@@ -420,6 +422,104 @@ def test_resnet18_half(ir):
torch._dynamo.reset()


@pytest.mark.unit
def test_cosmos_true_div(ir):
class CosmosLearnablePositionalEmbed(torch.nn.Module):
def __init__(
self,
hidden_size: int,
max_size: tuple[int, int, int],
patch_size: tuple[int, int, int],
eps: float = 1e-6,
) -> None:
super().__init__()

self.max_size = [size // patch for size, patch in zip(max_size, patch_size)]
self.patch_size = patch_size
self.eps = eps

self.pos_emb_t = nn.Parameter(torch.randn(self.max_size[0], hidden_size))
self.pos_emb_h = nn.Parameter(torch.randn(self.max_size[1], hidden_size))
self.pos_emb_w = nn.Parameter(torch.randn(self.max_size[2], hidden_size))

def forward(
self,
hidden_states: torch.Tensor,
num_ranks: Optional[int] = None,
rank_id: Optional[torch.Tensor] = None,
) -> torch.Tensor:
batch_size, num_channels, num_frames, height, width = hidden_states.shape
pe_size = [
num_frames // self.patch_size[0],
height // self.patch_size[1],
width // self.patch_size[2],
]
if num_ranks is not None and rank_id is not None:
pe_size[0] = pe_size[0] * num_ranks

# Use expand() instead of repeat() - torch_tensorrt compatible
# expand() creates a view without copying data, better for dynamic shapes
emb_t = self.pos_emb_t[: pe_size[0]][None, :, None, None, :].expand(
batch_size, -1, pe_size[1], pe_size[2], -1
)
emb_h = self.pos_emb_h[: pe_size[1]][None, None, :, None, :].expand(
batch_size, pe_size[0], -1, pe_size[2], -1
)
emb_w = self.pos_emb_w[: pe_size[2]][None, None, None, :, :].expand(
batch_size, pe_size[0], pe_size[1], -1, -1
)
emb = emb_t + emb_h + emb_w
emb = emb.flatten(1, 3)

norm = torch.linalg.vector_norm(
emb, dim=-1, keepdim=True, dtype=torch.float32
)
# With num_frames dynamic, numel() is symbolic, so this division yields a
# SymFloat and lowers to operator.truediv (the op under test here).
alpha = (norm.numel() / emb.numel()) ** 0.5
# Statically equivalent to:
# hidden_size = emb.shape[-1]
# alpha = (1.0 / hidden_size) ** 0.5
norm = torch.add(self.eps, norm, alpha=alpha)
return (emb / norm).type_as(hidden_states)

with torch.no_grad():
hidden_states = torch.randn(1, 16, 16, 88, 160).cuda()
model = CosmosLearnablePositionalEmbed(
hidden_size=4096,
max_size=(128, 240, 240),
patch_size=(1, 2, 2),
)
model.eval().cuda()
pyt_output = model(hidden_states)
num_latent_frames = torch.export.Dim("num_latent_frames", min=1, max=16)

ep = torch.export.export(
model,
args=(hidden_states,),
dynamic_shapes=({2: num_latent_frames},), # Make dimension 2 dynamic
strict=False,
)
trt_model = torchtrt.dynamo.compile(
ep,
inputs=(hidden_states,),
enabled_precisions={torch.bfloat16},
use_explicit_typing=False,
use_fp32_acc=False,
device="cuda:0",
disable_tf32=True,
use_python_runtime=True,
min_block_size=1,
)
trt_output = trt_model(hidden_states)

cos_sim = cosine_similarity(pyt_output, trt_output)
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"Cosmos Learnable Positional Embed TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)

# Clean up model env
torch._dynamo.reset()


@pytest.mark.unit
@unittest.skipIf(
torchtrt.ENABLED_FEATURES.tensorrt_rtx,