
Commit ed80070

Enabled Qwen MoE with 1 layer. Rewrote index_put converter
1 parent 9241476 commit ed80070

File tree

4 files changed: +185 -91 lines changed


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 4 additions & 1 deletion
@@ -883,6 +883,7 @@ def aten_ops_select(
 
 @dynamo_tensorrt_converter(
     torch.ops.aten.index_put.default,
+    supports_dynamic_shapes=True,
 )
 @enforce_tensor_types(
     {
@@ -3163,7 +3164,9 @@ def aten_ops_upsample_bicubic2d(
 
 
 @dynamo_tensorrt_converter(
-    torch.ops.aten.topk.default, capability_validator=topk_validator
+    torch.ops.aten.topk.default,
+    capability_validator=topk_validator,
+    supports_dynamic_shapes=True,
 )
 @enforce_tensor_types(
     {
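For orientation, a minimal sketch of what `supports_dynamic_shapes=True` signals when a converter is registered: the registry may select the converter even when the node's inputs carry dynamic dimensions, instead of falling back to PyTorch. The import path and the stub body below are assumptions for illustration; only the decorator keywords come from this commit.

# Hedged sketch, not part of this commit. Import path assumed; decorator
# keywords mirror the registrations changed above.
import torch
from torch_tensorrt.dynamo.conversion._ConverterRegistry import dynamo_tensorrt_converter


@dynamo_tensorrt_converter(
    torch.ops.aten.topk.default,
    supports_dynamic_shapes=True,  # eligible even when input dims are dynamic
)
def aten_ops_topk_sketch(ctx, target, args, kwargs, name):
    # A real converter builds TensorRT layers from the FX node here.
    ...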

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 123 additions & 80 deletions
@@ -502,22 +502,40 @@ def index_put_converter(
     K = len(I)
     # Determine the maximum size 'N' among the index tensors
     if K > 0:
-        index_shapes = [tensor.shape[0] for tensor in indices if tensor is not None]
+        index_shapes = (
+            []
+        )  # [tensor.shape[0] for tensor in indices if tensor is not None]
+        for idx_tensor in indices:
+            if idx_tensor is not None:
+                if idx_tensor.shape[0] != DYNAMIC_DIM:
+                    index_shapes.append(idx_tensor.shape[0])
+                else:
+                    index_shapes.append(
+                        get_shape(
+                            ctx,
+                            target,
+                            source_ir,
+                            name + "idx_shape_dim_0",
+                            idx_tensor,
+                            0,
+                        )
+                    )
         N = max(index_shapes) if index_shapes else 1
     else:
         N = 1
 
     # Compute shapes and volume for the free dimensions
     F_shapes = [input_tensor.shape[i] for i in F]
+    assert -1 not in F_shapes, "Dynamic shape in free dimensions is not supported"
     F_volume = trt.volume(F_shapes) if F_shapes else 1
 
     # Process indexed dimensions (I)
     I_tensors = []
     for i in I:
         idx = indices[i]
         assert idx is not None
-        idx_reshaped = impl.shuffle.reshape(
-            ctx, target, source_ir, f"{name}_reshape_idx_I_{i}", idx, (idx.shape[0], 1)
+        idx_reshaped = impl.unsqueeze.unsqueeze(
+            ctx, target, source_ir, f"{name}_unsqueeze_idx_I_{i}", idx, 1
         )
         expanded_idx = impl.slice.expand(
             ctx,
@@ -539,46 +557,50 @@ def index_put_converter(
         )
         arange_tensors.append(arange_tensor)
 
-    meshgrid_tensors = []
-    for i, arange in enumerate(arange_tensors):
-        reshape_shape = [1] * len(F)
-        reshape_shape[i] = F_shapes[i]
-        arange_reshaped = impl.shuffle.reshape(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_reshape_arange_F_{F[i]}",
-            arange,
-            tuple(reshape_shape),
-        )
-        expanded_arange = impl.slice.expand(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_expand_arange_F_{F[i]}",
-            arange_reshaped,
-            tuple(F_shapes),
-        )
-        meshgrid_tensors.append(expanded_arange)
-
-    meshgrid_stacked = impl.cat.cat(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_stack_meshgrid",
-        [
-            impl.shuffle.reshape(
-                ctx,
-                target,
-                source_ir,
-                f"{name}_reshape_mesh_{i}",
-                t,
-                (*F_shapes, 1),
-            )
-            for i, t in enumerate(meshgrid_tensors)
-        ],
-        dim=-1,
-    )
+    if len(arange_tensors) == 1:
+        # No need to stack
+        meshgrid_stacked = arange_tensors[0]
+    else:
+        meshgrid_tensors = []
+        for i, arange in enumerate(arange_tensors):
+            reshape_shape = [1] * len(F)
+            reshape_shape[i] = F_shapes[i]
+            arange_reshaped = impl.shuffle.reshape(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_reshape_arange_F_{F[i]}",
+                arange,
+                tuple(reshape_shape),
+            )
+            expanded_arange = impl.slice.expand(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_expand_arange_F_{F[i]}",
+                arange_reshaped,
+                tuple(F_shapes),
+            )
+            meshgrid_tensors.append(expanded_arange)
+
+        meshgrid_stacked = impl.cat.cat(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_stack_meshgrid",
+            [
+                impl.shuffle.reshape(
+                    ctx,
+                    target,
+                    source_ir,
+                    f"{name}_reshape_mesh_{i}",
+                    t,
+                    (*F_shapes, 1),
+                )
+                for i, t in enumerate(meshgrid_tensors)
+            ],
+            dim=-1,
+        )
     meshgrid_reshaped = impl.shuffle.reshape(
         ctx,
         target,
@@ -603,21 +625,15 @@ def index_put_converter(
 
     # Combine all indexed dimensions (I)
     if K > 0:
-        I_combined = impl.cat.cat(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_cat_I",
-            [
-                impl.shuffle.reshape(
-                    ctx, target, source_ir, f"{name}_reshape_I_{i}", t, (N, F_volume, 1)
-                )
-                for i, t in enumerate(I_tensors)
-            ],
-            dim=2,
-        )
+
+        I_combined = [
+            impl.shuffle.reshape(
+                ctx, target, source_ir, f"{name}_reshape_I_{i}", t, (N, F_volume, 1)
+            )
+            for i, t in enumerate(I_tensors)
+        ]
     else:
-        I_combined = None
+        I_combined = []
 
     # Build the final index list (ii_list) by slicing either I_combined or meshgrid_expanded
     ii_list = []
@@ -626,24 +642,12 @@ def index_put_converter(
     for dim in range(rank):
         unique_suffix = f"{dim}_{i_idx if dim in I else f_idx}"
         if dim in I:
-            start = [0, 0, i_idx]
-            shape = [N, F_volume, 1]
-            stride = [1, 1, 1]
-            idx_tensor = impl.slice.slice(
-                ctx,
-                target,
-                source_ir,
-                f"{name}_slice_I_dim_{unique_suffix}",
-                I_combined,
-                start,
-                shape,
-                stride,
-            )
+            idx_tensor = I_combined[i]
             ii_list.append(idx_tensor)
             i_idx += 1
         else:
             start = [0, 0, f_idx]
-            shape = [N, F_volume, 1]
+            shape = [-1, F_volume, 1] if isinstance(N, TRTTensor) else [N, F_volume, 1]
             stride = [1, 1, 1]
             mesh_tensor = impl.slice.slice(
                 ctx,
@@ -662,20 +666,24 @@ def index_put_converter(
     indices_cat = impl.cat.cat(
         ctx, target, source_ir, f"{name}_cat_indices", ii_list, dim=2
     )
+
+    # Flatten the indices_cat to (N * F_volume, rank)
     indices_cat = impl.shuffle.reshape(
         ctx,
         target,
         source_ir,
         f"{name}_reshape_indices_cat",
         indices_cat,
-        (N * F_volume, rank),
+        (-1, rank),
    )
 
     if not isinstance(values, TRTTensor):
         values = get_trt_tensor(ctx, values, f"{name}_values", min_rank=0)
 
     # Define the expected shape based on (N,) + F_shapes
-    expected_shape = (N,) + tuple(F_shapes)
+    expected_shape = (
+        (-1,) + tuple(F_shapes) if isinstance(N, TRTTensor) else (N,) + tuple(F_shapes)
+    )
 
     # Broadcast 'values' to match the expected shape
     if len(values.shape) == 0 or values.shape == (1,):  # Scalar case
@@ -773,16 +781,51 @@ def index_put_converter(
         source_ir,
         f"{name}_flatten_values",
         values_expanded,
-        (N * F_volume,),
+        (-1,),
     )
-
     indices_cat = cast_trt_tensor(ctx, indices_cat, trt.int32, f"{name}_idx_int32")
-    # Perform Scatter ND operation
-    scatter_layer = ctx.net.add_scatter(
-        input_tensor,
-        indices_cat,
-        flattened_values,
-        trt.ScatterMode.ND if not accumulate else trt.ScatterMode.ND_ELEMENTWISE_ADD,
-    )
-    set_layer_name(scatter_layer, target, f"{name}_scatter", source_ir)
-    return scatter_layer.get_output(0)
+    if accumulate:
+        zero_tensor = impl.full.full(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_zero_tensor",
+            [
+                get_shape(
+                    ctx,
+                    target,
+                    source_ir,
+                    name + f"input_tensor_shape_dim_{i}",
+                    input_tensor,
+                    i,
+                )
+                for i in range(len(input_tensor.shape))
+            ],
+            0.0,
+            dtype=input_tensor.dtype,
+        )
+        # Perform Scatter ND operation
+        scatter_layer = ctx.net.add_scatter(
+            zero_tensor,
+            indices_cat,
+            flattened_values,
+            trt.ScatterMode.ND,
+        )
+        set_layer_name(scatter_layer, target, f"{name}_scatter", source_ir)
+
+        scatter_out = scatter_layer.get_output(0)
+        result = impl.elementwise.add(
+            ctx, target, source_ir, f"{name}_add", scatter_out, input_tensor
+        )
+        return result
+
+    else:
+        scatter_layer = ctx.net.add_scatter(
+            input_tensor,
+            indices_cat,
+            flattened_values,
+            trt.ScatterMode.ND,
+        )
+        set_layer_name(scatter_layer, target, f"{name}_scatter", source_ir)
+        scatter_out = scatter_layer.get_output(0)
+        return scatter_out
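The rewritten accumulate branch no longer relies on trt.ScatterMode.ND_ELEMENTWISE_ADD: it scatters the flattened values into a zero tensor with plain ScatterMode.ND and then adds the result to the input tensor. A hedged, plain-PyTorch sketch of that decomposition (illustration only; it matches accumulate=True semantics when each index appears at most once):

# Plain-PyTorch sketch of the accumulate decomposition used by the converter above.
import torch

x = torch.arange(6.0).reshape(3, 2)    # input_tensor
idx = (torch.tensor([0, 2]),)          # indices along dim 0
vals = torch.ones(2, 2)                # values to accumulate

zeros = torch.zeros_like(x)                                      # impl.full.full(...)
scattered = torch.index_put(zeros, idx, vals, accumulate=False)  # ScatterMode.ND into zeros
result = scattered + x                                           # impl.elementwise.add(...)

# Equivalent to accumulate=True as long as the indices are unique.
assert torch.equal(result, torch.index_put(x, idx, vals, accumulate=True))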

tests/py/dynamo/conversion/test_index_put_aten.py

Lines changed: 52 additions & 0 deletions
@@ -1,4 +1,5 @@
 import torch
+import torch_tensorrt as torchtrt
 from parameterized import param, parameterized
 from torch.testing._internal.common_utils import run_tests
 
@@ -244,6 +245,57 @@ def forward(self, source_tensor, value_tensor):
             use_dynamo_tracer=True,
         )
 
+    def test_index_add_dynamic_shape(self):
+
+        class Model(torch.nn.Module):
+            def forward(self, x, y, z, a, b):
+                x.index_add_(0, y, z)
+                x.index_add_(0, a, b)
+                return x
+
+        dim = 10
+        model = Model().cuda()
+        inputs = [
+            torch.ones((12, dim)).half().cuda(),
+            torch.tensor([0, 1]).cuda(),
+            torch.randn((2, dim)).half().cuda(),
+            torch.tensor([2, 9, 11]).cuda(),
+            torch.randn((3, dim)).half().cuda(),
+        ]
+        torch_output = model.cuda().forward(*inputs)
+        seq_len1 = torch.export.Dim("seq_len1", min=1, max=128)
+        seq_len2 = torch.export.Dim("seq_len2", min=1, max=128)
+        seq_len3 = torch.export.Dim("seq_len3", min=1, max=128)
+
+        ep = torch.export.export(
+            model,
+            tuple(inputs),
+            dynamic_shapes=(
+                {0: seq_len1},
+                {0: seq_len2},
+                {0: seq_len2},
+                {0: seq_len3},
+                {0: seq_len3},
+            ),
+        )
+        with torchtrt.dynamo.Debugger(
+            log_level="debug",
+            capture_fx_graph_after=["remove_num_users_is_0_nodes"],
+            logging_dir="/home/profile/logging/moe",
+            engine_builder_monitor=False,
+        ):
+            trt_mod = torchtrt.dynamo.compile(
+                ep,
+                inputs,
+                enabled_precisions={torch.float16},
+                min_block_size=1,
+                use_explicit_typing=False,
+                use_fp32_acc=False,
+                disable_tf32=True,
+            )
+        result = trt_mod(*inputs)
+        assert torch.allclose(result, torch_output, atol=1e-4, rtol=1e-4)
+
 
 if __name__ == "__main__":
     run_tests()

tools/llm/run_llm.py

Lines changed: 6 additions & 10 deletions
@@ -49,15 +49,11 @@ def get_model(args):
         moved to CUDA device with the specified precision
     """
     with torch.no_grad():
-        model = (
-            AutoModelForCausalLM.from_pretrained(
-                args.model,
-                use_cache=False,
-                attn_implementation="sdpa",
-            )
-            .eval()
-            .cuda()
-        )
+        model = AutoModelForCausalLM.from_pretrained(
+            args.model,
+            use_cache=False,
+            attn_implementation="sdpa",
+        ).eval()
 
     if args.precision == "FP16":
         model = model.to(torch.float16)
@@ -66,7 +62,7 @@ def get_model(args):
     else:
         model = model.to(torch.float32)
 
-    return model
+    return model.cuda()
 
 
 def compile_torchtrt(model, input_ids, args):
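The run_llm.py change defers the .cuda() move until after the precision cast, so the weights are transferred once, already in the target dtype. A hedged sketch of the resulting load order (the model id is a placeholder standing in for args.model; the memory rationale is an assumption, not stated in the commit):

# Hedged sketch of the load order after this change.
import torch
from transformers import AutoModelForCausalLM

with torch.no_grad():
    model = AutoModelForCausalLM.from_pretrained(
        "Qwen/Qwen2-0.5B-Instruct",  # placeholder; run_llm.py passes args.model
        use_cache=False,
        attn_implementation="sdpa",
    ).eval()

model = model.to(torch.float16)  # cast on the host first (FP16 path)
model = model.cuda()             # then move the half-precision weights to the GPU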
