From 5b2e8c481af00a55e2fa2a4456000e9a4a1e0369 Mon Sep 17 00:00:00 2001 From: Evan Li Date: Wed, 27 Aug 2025 15:23:36 -0700 Subject: [PATCH] support dynamics for all inputs --- .../dynamo/conversion/impl/embedding.py | 57 +++++++++++-------- .../conversion/test_embedding_bag_aten.py | 12 ++-- 2 files changed, 42 insertions(+), 27 deletions(-) diff --git a/py/torch_tensorrt/dynamo/conversion/impl/embedding.py b/py/torch_tensorrt/dynamo/conversion/impl/embedding.py index a712641f44..0a723618eb 100644 --- a/py/torch_tensorrt/dynamo/conversion/impl/embedding.py +++ b/py/torch_tensorrt/dynamo/conversion/impl/embedding.py @@ -149,7 +149,12 @@ def embedding_bag_with_ITensor_offsets( mode: int, include_last_offset: bool, ) -> Tuple[TRTTensor, TRTTensor, TRTTensor, TRTTensor]: - len_embed = embed.shape[0] + # prepare some tensors for future use + constant_0 = get_trt_tensor(ctx, 0, f"{name}_constant_tensor_0") + constant_1 = get_trt_tensor(ctx, 1, f"{name}_constant_tensor_1") + + embed_shape = ctx.net.add_shape(embed).get_output(0) + len_embed = ctx.net.add_gather(embed_shape, constant_0, 0).get_output(0) if include_last_offset: # modify the last index of offsets to the end index @@ -166,32 +171,38 @@ def embedding_bag_with_ITensor_offsets( # create a placeholder tensor, whose shape is the same as an embedding # if mode is 0 (sum) or 1 (mean), the placeholder tensor is filled with zeros # if mode is 2 (max), the placeholder tensor is filled with negative infinity - placeholder_tensor = ( - get_trt_tensor( + placeholder_tensor = impl.elementwise.mul( + ctx, target, source_ir, f"{name}_zero_tensors", embed, 0 + ) + zero_tensor = ctx.net.add_gather(placeholder_tensor, constant_0, 0).get_output(0) + zero_tensor = impl.shuffle.reshape( + ctx, target, source_ir, f"{name}_reshape_zero_tensor", zero_tensor, (-1,) + ) + + if mode == 2: + placeholder_tensor = impl.elementwise.add( ctx, - np.full(embed.shape, -np.inf, dtype=np.float32), + target, + source_ir, f"{name}_negative_inf_tensor", + placeholder_tensor, + -np.inf, ) - if mode == 2 - else get_trt_tensor( - ctx, np.zeros(embed.shape, dtype=np.float32), f"{name}_zero_tensors" - ) - ) - - # prepare some tensors for future use - zero_tensor = get_trt_tensor( - ctx, np.zeros((embed.shape[1],), dtype=np.float32), f"{name}_zero_tensor" - ) - constant_0 = get_trt_tensor(ctx, 0, f"{name}_constant_tensor_0") - constant_1 = get_trt_tensor(ctx, 1, f"{name}_constant_tensor_1") # Use two for loops to calculate the embedding of each bag ###### Outer loop: traverse offsets ###### loop1 = ctx.net.add_loop() - trip_limit1 = ctx.net.add_constant( - shape=(), - weights=trt.Weights(np.array([offsets.shape[0] - 1], dtype=np.int32)), - ).get_output(0) + + offsets_shape = ctx.net.add_shape(offsets).get_output(0) + offsets_size = ctx.net.add_gather(offsets_shape, constant_0, 0).get_output(0) + trip_limit1 = impl.elementwise.sub( + ctx, target, source_ir, f"{name}_trip_limit1", offsets_size, 1 + ) + # change to 0d tensor + trip_limit1 = impl.shuffle.reshape( + ctx, target, source_ir, f"{name}_reshape_trip_limit1", trip_limit1, () + ) + loop1.add_trip_limit(trip_limit1, trt.TripLimit.COUNT) rec1_i_tensor = loop1.add_recurrence(constant_1) @@ -210,9 +221,9 @@ def embedding_bag_with_ITensor_offsets( ###### Inner loop: traverse indices ###### loop2 = ctx.net.add_loop() - trip_limit2 = ctx.net.add_constant( - shape=(), weights=trt.Weights(np.array([len_embed], dtype=np.int32)) - ).get_output(0) + trip_limit2 = impl.shuffle.reshape( + ctx, target, source_ir, f"{name}_reshape_trip_limit2", len_embed, () + ) loop2.add_trip_limit(trip_limit2, trt.TripLimit.COUNT) rec2_j_tensor = loop2.add_recurrence(constant_0) set_layer_name(rec2_j_tensor, target, f"{name}_rec2_j_tensor", source_ir) diff --git a/tests/py/dynamo/conversion/test_embedding_bag_aten.py b/tests/py/dynamo/conversion/test_embedding_bag_aten.py index 1f119bd77e..122cc0b7fe 100644 --- a/tests/py/dynamo/conversion/test_embedding_bag_aten.py +++ b/tests/py/dynamo/conversion/test_embedding_bag_aten.py @@ -438,16 +438,20 @@ def forward(self, weight, indices, offsets): weights=torch.randn((5, 2), dtype=torch.float32), # weights_1 is for inference weights_1=torch.randn((6, 3), dtype=torch.float32), + indices=torch.tensor([1, 2, 4, 2, 3, 4], dtype=torch.int32), + offsets=torch.tensor([0, 2, 4], dtype=torch.int32), dynamic_shapes={ "weights": { 0: torch.export.Dim("dyn_dim", min=2, max=8), 1: torch.export.Dim("dyn_dim_1", min=1, max=3), }, - "indices": {}, - "offsets": {}, + "indices": { + 0: torch.export.Dim("dyn_dim_in", min=2, max=32), + }, + "offsets": { + 0: torch.export.Dim("dyn_dim_off", min=2, max=32), + }, }, - indices=torch.tensor([1, 2, 4, 2, 3, 4], dtype=torch.int32), - offsets=torch.tensor([0, 2, 4], dtype=torch.int32), mode=1, per_sample_weights=None, ),