-
Notifications
You must be signed in to change notification settings - Fork 370
Open
Labels
bug: Something isn't working
Description
import torch
import torch.nn as nn
import torch_tensorrt as torchtrt
class Model(nn.Module):
    """Minimal repro module: scatter-add rows of ``z`` into ``x`` in place.

    ``forward`` takes five tensors so the exported graph matches the repro's
    input list; ``a`` and ``b`` are only used by the commented variant below.
    """

    def forward(self, x, y, z, a, b):
        """Return ``x`` after ``x.index_add_(0, y, z)``.

        Args:
            x: Destination tensor; rows are accumulated into it IN PLACE.
            y: 1-D index tensor selecting rows (dim 0) of ``x``.
            z: Source rows added at the positions in ``y``;
                ``z.shape[0]`` must equal ``len(y)``.
            a: Unused (index set for the variant below).
            b: Unused (source rows for the variant below).

        Returns:
            The very same tensor object as ``x`` (mutated), so callers that
            reuse their inputs afterwards observe the accumulated values.
        """
        x.index_add_(0, y, z)
        # Variant that additionally exercises a second, differently sized
        # index set:
        #     x = x.index_add_(0, a, b)
        return x
model = Model().cuda()
inputs = [
    torch.randn((12, 2048)).half().cuda(),
    torch.randint(0, 12, (5,)).cuda(),
    torch.randn((5, 2048)).half().cuda(),
    torch.randint(0, 12, (3,)).cuda(),
    torch.randn((3, 2048)).half().cuda(),
]


def _fresh_inputs():
    """Return independent copies of ``inputs``.

    ``Model.forward`` mutates its first argument in place and returns it, so
    reusing ``inputs`` across runs would (a) make the reference output alias
    ``inputs[0]`` and (b) feed already-accumulated data to later runs.
    """
    return [t.clone() for t in inputs]


# Reference result from eager PyTorch on its own copy of the inputs.
torch_output = model(*_fresh_inputs())

seq_len1 = torch.export.Dim("seq_len1", min=1, max=128)
seq_len2 = torch.export.Dim("seq_len2", min=1, max=128)
seq_len3 = torch.export.Dim("seq_len3", min=1, max=128)
# y/z share seq_len2 and a/b share seq_len3: index_add_ requires the index
# length to match the number of source rows.
ep = torch.export.export(
    model,
    tuple(_fresh_inputs()),
    dynamic_shapes=(
        {0: seq_len1},
        {0: seq_len2},
        {0: seq_len2},
        {0: seq_len3},
        {0: seq_len3},
    ),
)

with torchtrt.dynamo.Debugger(
    log_level="debug",
    capture_fx_graph_after=["remove_num_users_is_0_nodes"],
    logging_dir="/home/profile/logging/moe",
    engine_builder_monitor=False,
):
    trt_mod = torchtrt.dynamo.compile(
        ep,
        _fresh_inputs(),
        enabled_precisions={torch.float16},
        min_block_size=1,
        use_explicit_typing=False,
        use_fp32_acc=False,
        disable_tf32=True,
    )

trt_output = trt_mod(*_fresh_inputs())
# Compare in fp32 and use absolute error: the mean of a signed difference
# can cancel out and hide real mismatches.
abs_err = (trt_output.float() - torch_output.float()).abs()
print("max abs diff:", abs_err.max().item())
print("mean abs diff:", abs_err.mean().item())
Metadata
Metadata
Assignees
Labels
bug: Something isn't working