
Commit 754743b

Enabled Qwen MoE with 1 layer. Rewrote index_put converter
1 parent 1b23905 commit 754743b

4 files changed: +180 -82 lines


py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py

Lines changed: 4 additions & 1 deletion
@@ -888,6 +888,7 @@ def aten_ops_select(

 @dynamo_tensorrt_converter(
     torch.ops.aten.index_put.default,
+    supports_dynamic_shapes=True,
 )
 @enforce_tensor_types(
     {
@@ -3168,7 +3169,9 @@ def aten_ops_upsample_bicubic2d(


 @dynamo_tensorrt_converter(
-    torch.ops.aten.topk.default, capability_validator=topk_validator
+    torch.ops.aten.topk.default,
+    capability_validator=topk_validator,
+    supports_dynamic_shapes=True,
 )
 @enforce_tensor_types(
     {
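For context, a minimal sketch (not part of this commit) of the kind of dynamically shaped export that supports_dynamic_shapes=True makes eligible for these index_put/topk registrations; the module, dimension name, and sizes below are illustrative only:

import torch
import torch_tensorrt as torchtrt

class TopK(torch.nn.Module):
    def forward(self, x):
        # top-3 values/indices along the last dimension
        return torch.topk(x, k=3, dim=-1)

# Mark the first dimension as dynamic, then compile the exported program.
seq = torch.export.Dim("seq", min=2, max=64)
ep = torch.export.export(TopK(), (torch.randn(8, 16),), dynamic_shapes=({0: seq},))
trt_mod = torchtrt.dynamo.compile(ep, [torch.randn(8, 16)], min_block_size=1)
values, indices = trt_mod(torch.randn(32, 16))  # any first-dim size within [2, 64]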

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 123 additions & 80 deletions
@@ -571,22 +571,40 @@ def index_put_converter(
     K = len(I)
     # Determine the maximum size 'N' among the index tensors
     if K > 0:
-        index_shapes = [tensor.shape[0] for tensor in indices if tensor is not None]
+        index_shapes = (
+            []
+        )  # [tensor.shape[0] for tensor in indices if tensor is not None]
+        for idx_tensor in indices:
+            if idx_tensor is not None:
+                if idx_tensor.shape[0] != DYNAMIC_DIM:
+                    index_shapes.append(idx_tensor.shape[0])
+                else:
+                    index_shapes.append(
+                        get_shape(
+                            ctx,
+                            target,
+                            source_ir,
+                            name + "idx_shape_dim_0",
+                            idx_tensor,
+                            0,
+                        )
+                    )
         N = max(index_shapes) if index_shapes else 1
     else:
         N = 1

     # Compute shapes and volume for the free dimensions
     F_shapes = [input_tensor.shape[i] for i in F]
+    assert -1 not in F_shapes, "Dynamic shape in free dimensions is not supported"
     F_volume = trt.volume(F_shapes) if F_shapes else 1

     # Process indexed dimensions (I)
     I_tensors = []
     for i in I:
         idx = indices[i]
         assert idx is not None
-        idx_reshaped = impl.shuffle.reshape(
-            ctx, target, source_ir, f"{name}_reshape_idx_I_{i}", idx, (idx.shape[0], 1)
+        idx_reshaped = impl.unsqueeze.unsqueeze(
+            ctx, target, source_ir, f"{name}_unsqueeze_idx_I_{i}", idx, 1
         )
         expanded_idx = impl.slice.expand(
             ctx,
@@ -608,46 +626,50 @@ def index_put_converter(
         )
         arange_tensors.append(arange_tensor)

-    meshgrid_tensors = []
-    for i, arange in enumerate(arange_tensors):
-        reshape_shape = [1] * len(F)
-        reshape_shape[i] = F_shapes[i]
-        arange_reshaped = impl.shuffle.reshape(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_reshape_arange_F_{F[i]}",
-            arange,
-            tuple(reshape_shape),
-        )
-        expanded_arange = impl.slice.expand(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_expand_arange_F_{F[i]}",
-            arange_reshaped,
-            tuple(F_shapes),
-        )
-        meshgrid_tensors.append(expanded_arange)
-
-    meshgrid_stacked = impl.cat.cat(
-        ctx,
-        target,
-        source_ir,
-        f"{name}_stack_meshgrid",
-        [
-            impl.shuffle.reshape(
+    if len(arange_tensors) == 1:
+        # No need to stack
+        meshgrid_stacked = arange_tensors[0]
+    else:
+        meshgrid_tensors = []
+        for i, arange in enumerate(arange_tensors):
+            reshape_shape = [1] * len(F)
+            reshape_shape[i] = F_shapes[i]
+            arange_reshaped = impl.shuffle.reshape(
                 ctx,
                 target,
                 source_ir,
-                f"{name}_reshape_mesh_{i}",
-                t,
-                (*F_shapes, 1),
+                f"{name}_reshape_arange_F_{F[i]}",
+                arange,
+                tuple(reshape_shape),
             )
-            for i, t in enumerate(meshgrid_tensors)
-        ],
-        dim=-1,
-    )
+            expanded_arange = impl.slice.expand(
+                ctx,
+                target,
+                source_ir,
+                f"{name}_expand_arange_F_{F[i]}",
+                arange_reshaped,
+                tuple(F_shapes),
+            )
+            meshgrid_tensors.append(expanded_arange)
+
+        meshgrid_stacked = impl.cat.cat(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_stack_meshgrid",
+            [
+                impl.shuffle.reshape(
+                    ctx,
+                    target,
+                    source_ir,
+                    f"{name}_reshape_mesh_{i}",
+                    t,
+                    (*F_shapes, 1),
+                )
+                for i, t in enumerate(meshgrid_tensors)
+            ],
+            dim=-1,
+        )
     meshgrid_reshaped = impl.shuffle.reshape(
         ctx,
         target,
@@ -672,21 +694,15 @@ def index_put_converter(

     # Combine all indexed dimensions (I)
     if K > 0:
-        I_combined = impl.cat.cat(
-            ctx,
-            target,
-            source_ir,
-            f"{name}_cat_I",
-            [
-                impl.shuffle.reshape(
-                    ctx, target, source_ir, f"{name}_reshape_I_{i}", t, (N, F_volume, 1)
-                )
-                for i, t in enumerate(I_tensors)
-            ],
-            dim=2,
-        )
+
+        I_combined = [
+            impl.shuffle.reshape(
+                ctx, target, source_ir, f"{name}_reshape_I_{i}", t, (N, F_volume, 1)
+            )
+            for i, t in enumerate(I_tensors)
+        ]
     else:
-        I_combined = None
+        I_combined = []

     # Build the final index list (ii_list) by slicing either I_combined or meshgrid_expanded
     ii_list = []
@@ -695,24 +711,12 @@ def index_put_converter(
     for dim in range(rank):
         unique_suffix = f"{dim}_{i_idx if dim in I else f_idx}"
         if dim in I:
-            start = [0, 0, i_idx]
-            shape = [N, F_volume, 1]
-            stride = [1, 1, 1]
-            idx_tensor = impl.slice.slice(
-                ctx,
-                target,
-                source_ir,
-                f"{name}_slice_I_dim_{unique_suffix}",
-                I_combined,
-                start,
-                shape,
-                stride,
-            )
+            idx_tensor = I_combined[i]
             ii_list.append(idx_tensor)
             i_idx += 1
         else:
             start = [0, 0, f_idx]
-            shape = [N, F_volume, 1]
+            shape = [-1, F_volume, 1] if isinstance(N, TRTTensor) else [N, F_volume, 1]
             stride = [1, 1, 1]
             mesh_tensor = impl.slice.slice(
                 ctx,
@@ -731,20 +735,24 @@ def index_put_converter(
     indices_cat = impl.cat.cat(
         ctx, target, source_ir, f"{name}_cat_indices", ii_list, dim=2
     )
+
+    # Flatten the indices_cat to (N * F_volume, rank)
     indices_cat = impl.shuffle.reshape(
         ctx,
         target,
         source_ir,
         f"{name}_reshape_indices_cat",
         indices_cat,
-        (N * F_volume, rank),
+        (-1, rank),
     )

     if not isinstance(values, TRTTensor):
         values = get_trt_tensor(ctx, values, f"{name}_values", min_rank=0)

     # Define the expected shape based on (N,) + F_shapes
-    expected_shape = (N,) + tuple(F_shapes)
+    expected_shape = (
+        (-1,) + tuple(F_shapes) if isinstance(N, TRTTensor) else (N,) + tuple(F_shapes)
+    )

     # Broadcast 'values' to match the expected shape
     if len(values.shape) == 0 or values.shape == (1,):  # Scalar case
@@ -842,16 +850,51 @@ def index_put_converter(
         source_ir,
         f"{name}_flatten_values",
         values_expanded,
-        (N * F_volume,),
+        (-1,),
     )
-
     indices_cat = cast_trt_tensor(ctx, indices_cat, trt.int32, f"{name}_idx_int32")
-    # Perform Scatter ND operation
-    scatter_layer = ctx.net.add_scatter(
-        input_tensor,
-        indices_cat,
-        flattened_values,
-        trt.ScatterMode.ND if not accumulate else trt.ScatterMode.ND_ELEMENTWISE_ADD,
-    )
-    set_layer_name(scatter_layer, target, f"{name}_scatter", source_ir)
-    return scatter_layer.get_output(0)
+    if accumulate:
+        zero_tensor = impl.full.full(
+            ctx,
+            target,
+            source_ir,
+            f"{name}_zero_tensor",
+            [
+                get_shape(
+                    ctx,
+                    target,
+                    source_ir,
+                    name + f"input_tensor_shape_dim_{i}",
+                    input_tensor,
+                    i,
+                )
+                for i in range(len(input_tensor.shape))
+            ],
+            0.0,
+            dtype=input_tensor.dtype,
+        )
+        # Perform Scatter ND operation
+        scatter_layer = ctx.net.add_scatter(
+            zero_tensor,
+            indices_cat,
+            flattened_values,
+            trt.ScatterMode.ND,
+        )
+        set_layer_name(scatter_layer, target, f"{name}_scatter", source_ir)
+
+        scatter_out = scatter_layer.get_output(0)
+        result = impl.elementwise.add(
+            ctx, target, source_ir, f"{name}_add", scatter_out, input_tensor
+        )
+        return result
+
+    else:
+        scatter_layer = ctx.net.add_scatter(
+            input_tensor,
+            indices_cat,
+            flattened_values,
+            trt.ScatterMode.ND,
+        )
+        set_layer_name(scatter_layer, target, f"{name}_scatter", source_ir)
+        scatter_out = scatter_layer.get_output(0)
+        return scatter_out
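A note on the rewritten accumulate branch above: instead of a single scatter in ND_ELEMENTWISE_ADD mode, the converter now scatters the flattened values into an all-zero tensor with plain ND mode and then adds the original input back elementwise. A minimal PyTorch sketch of why the two forms agree, assuming the scatter indices are unique (duplicate indices would still need an accumulating scatter):

import torch

x = torch.arange(12, dtype=torch.float32).reshape(3, 4)
idx = (torch.tensor([0, 2]),)  # rows to update
vals = torch.ones(2, 4)

# Reference semantics: accumulate=True adds 'vals' into the indexed rows.
ref = x.clone().index_put_(idx, vals, accumulate=True)

# Rewritten form: scatter into zeros (plain overwrite), then add the input back.
out = torch.zeros_like(x).index_put_(idx, vals) + x

assert torch.equal(out, ref)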

tests/py/dynamo/conversion/test_index_put_aten.py

Lines changed: 52 additions & 0 deletions
@@ -1,4 +1,5 @@
 import torch
+import torch_tensorrt as torchtrt
 from parameterized import param, parameterized
 from torch.testing._internal.common_utils import run_tests

@@ -244,6 +245,57 @@ def forward(self, source_tensor, value_tensor):
             use_dynamo_tracer=True,
         )

+    def test_index_add_dynamic_shape(self):
+
+        class Model(torch.nn.Module):
+            def forward(self, x, y, z, a, b):
+                x.index_add_(0, y, z)
+                x.index_add_(0, a, b)
+                return x
+
+        dim = 10
+        model = Model().cuda()
+        inputs = [
+            torch.ones((12, dim)).half().cuda(),
+            torch.tensor([0, 1]).cuda(),
+            torch.randn((2, dim)).half().cuda(),
+            torch.tensor([2, 9, 11]).cuda(),
+            torch.randn((3, dim)).half().cuda(),
+        ]
+        torch_output = model.cuda().forward(*inputs)
+        seq_len1 = torch.export.Dim("seq_len1", min=1, max=128)
+        seq_len2 = torch.export.Dim("seq_len2", min=1, max=128)
+        seq_len3 = torch.export.Dim("seq_len3", min=1, max=128)
+
+        ep = torch.export.export(
+            model,
+            tuple(inputs),
+            dynamic_shapes=(
+                {0: seq_len1},
+                {0: seq_len2},
+                {0: seq_len2},
+                {0: seq_len3},
+                {0: seq_len3},
+            ),
+        )
+        with torchtrt.dynamo.Debugger(
+            log_level="debug",
+            capture_fx_graph_after=["remove_num_users_is_0_nodes"],
+            logging_dir="/home/profile/logging/moe",
+            engine_builder_monitor=False,
+        ):
+            trt_mod = torchtrt.dynamo.compile(
+                ep,
+                inputs,
+                enabled_precisions={torch.float16},
+                min_block_size=1,
+                use_explicit_typing=False,
+                use_fp32_acc=False,
+                disable_tf32=True,
+            )
+            result = trt_mod(*inputs)
+        assert torch.allclose(result, torch_output, atol=1e-4, rtol=1e-4)
+

 if __name__ == "__main__":
     run_tests()
tools/llm/run_llm.py

Lines changed: 1 addition & 1 deletion
@@ -71,7 +71,7 @@ def get_model(args):
     else:
         model = model.to(torch.float32)

-    return model
+    return model.cuda()


 def compile_torchtrt(model, input_ids, args):
