Skip to content

Conversation

@zewenli98
Copy link
Collaborator

@zewenli98 zewenli98 commented Jan 8, 2026

Description

Added a lowering pass to replace SymInt with aten.sym_size. Users now can explicitly use dynamic dims in the models. e.g.:

import logging

import torch
import torch.nn as nn
import torch_tensorrt

logging.basicConfig(level=logging.DEBUG)

torch.manual_seed(0)


class ExpandReshapeModel(nn.Module):
    def __init__(self, embed_dim: int):
        super().__init__()
        self.cls_token = nn.Parameter(torch.randn(1, 1, embed_dim))
        self.embed_dim = embed_dim
        self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3)

    def forward(self, x: torch.Tensor):
        batch_size = x.shape[0]  # dynamic dim
        cls_token = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat([cls_token, x], dim=1)
        x = self.qkv_proj(x)
        reshaped_qkv = x.reshape(batch_size, x.size(1), 3, 12, -1)
        return reshaped_qkv


model = ExpandReshapeModel(embed_dim=768).cuda().eval()
x = torch.randn(4, 196, 768).cuda()
torch._dynamo.mark_dynamic(x, index=0, min=2, max=32)
trt_module = torch.compile(model, backend="tensorrt")
out = trt_module(x)
print(out)

Fixes #3981

Type of change

  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)

Checklist:

  • My code follows the style guidelines of this project (You can use the linters)
  • I have performed a self-review of my own code
  • I have commented my code, particularly in hard-to-understand areas and hacks
  • I have made corresponding changes to the documentation
  • I have added tests to verify my fix or my feature
  • New and existing unit tests pass locally with my changes
  • I have added the relevant labels to my PR in so that relevant reviewers are notified

@zewenli98 zewenli98 requested a review from narendasan January 8, 2026 18:57
@zewenli98 zewenli98 self-assigned this Jan 8, 2026
@meta-cla meta-cla bot added the cla signed label Jan 8, 2026
@github-actions github-actions bot added component: lowering Issues re: The lowering / preprocessing passes component: core Issues re: The core compiler component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Jan 8, 2026
@github-actions github-actions bot requested a review from cehongwang January 8, 2026 18:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cla signed component: api [Python] Issues re: Python API component: core Issues re: The core compiler component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: lowering Issues re: The lowering / preprocessing passes

Projects

None yet

Development

Successfully merging this pull request may close these issues.

🐛 [Bug] SymInt placeholder nodes not fully removed in remove_sym_nodes pass, causing issues with TensorRT lowering

2 participants