Skip to content

πŸ› [Bug] index_put fails with dynamic shapeΒ #3806

@cehongwang

Description

@cehongwang
import torch
import torch.nn as nn
import torch_tensorrt as torchtrt

class Model(nn.Module):
    """Minimal reproducer module that scatter-adds rows of ``z`` into ``x``.

    ``forward`` updates ``x`` in place with ``index_add_`` and returns that same
    tensor object, so callers' first input is modified. ``a`` and ``b`` are
    accepted but not used by the active code path.
    """


    def forward(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        z: torch.Tensor,
        a: torch.Tensor,
        b: torch.Tensor,
    ) -> torch.Tensor:
        # In-place along dim 0: x[y[i]] += z[i]; duplicate indices in y accumulate.
        x.index_add_(0, y, z)
        # A second accumulation using (a, b) is commented out. a and b are
        # presumably still taken as inputs so their dynamic dims remain part of
        # the exported signature -- confirm with the reporter.
        # x = x.index_add_(0, a, b)
        return x

# Reproducer driver: run the model eagerly, export it with dynamic dim-0 sizes,
# compile the exported program with Torch-TensorRT, and compare the outputs.
model = Model().cuda()
inputs = [
    torch.randn((12, 2048)).half().cuda(),  # x: destination tensor
    torch.randint(0, 12, (5,)).cuda(),  # y: row indices into x
    torch.randn((5, 2048)).half().cuda(),  # z: rows added at y
    torch.randint(0, 12, (3,)).cuda(),  # a
    torch.randn((3, 2048)).half().cuda(),  # b
]

# Model.forward updates its first argument in place and returns it. Running it
# on ``inputs`` directly would make the reference output an alias of inputs[0]
# and feed already-modified data to every later step. Keep an untouched
# snapshot (in case export/compile also execute the graph on the sample
# tensors) and give each actual execution its own clone.
pristine_inputs = [t.clone() for t in inputs]
torch_output = model(*[t.clone() for t in pristine_inputs])

seq_len1 = torch.export.Dim("seq_len1", min=1, max=128)
seq_len2 = torch.export.Dim("seq_len2", min=1, max=128)
seq_len3 = torch.export.Dim("seq_len3", min=1, max=128)

# y/z share seq_len2 and a/b share seq_len3: index_add_ requires the index
# length to equal the source tensor's size along the indexed dim.
ep = torch.export.export(
    model,
    tuple(inputs),
    dynamic_shapes=(
        {0: seq_len1},
        {0: seq_len2},
        {0: seq_len2},
        {0: seq_len3},
        {0: seq_len3},
    ),
)
# NOTE(review): logging_dir is a machine-specific path; point it at a writable
# directory when running this elsewhere.
with torchtrt.dynamo.Debugger(
    log_level="debug",
    capture_fx_graph_after=["remove_num_users_is_0_nodes"],
    logging_dir="/home/profile/logging/moe",
    engine_builder_monitor=False,
):
    trt_mod = torchtrt.dynamo.compile(
        ep,
        inputs,
        enabled_precisions={torch.float16},
        min_block_size=1,
        use_explicit_typing=False,
        use_fp32_acc=False,
        disable_tf32=True,
    )

trt_output = trt_mod(*[t.clone() for t in pristine_inputs])
# Report the largest absolute deviation: a signed mean lets positive and
# negative errors cancel out. The difference is taken in float32 so fp16
# rounding of the subtraction itself does not mask small errors.
print((trt_output.float() - torch_output.float()).abs().max())

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug — Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions