Skip to content

Commit cacf653

Browse files
authored
feat: support dynamics for all inputs for embedding_bag converter (#3796)
1 parent b88c0e4 commit cacf653

File tree

2 files changed: +42 -27 lines changed

py/torch_tensorrt/dynamo/conversion/impl/embedding.py

Lines changed: 34 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,12 @@ def embedding_bag_with_ITensor_offsets(
149149
mode: int,
150150
include_last_offset: bool,
151151
) -> Tuple[TRTTensor, TRTTensor, TRTTensor, TRTTensor]:
152-
len_embed = embed.shape[0]
152+
# prepare some tensors for future use
153+
constant_0 = get_trt_tensor(ctx, 0, f"{name}_constant_tensor_0")
154+
constant_1 = get_trt_tensor(ctx, 1, f"{name}_constant_tensor_1")
155+
156+
embed_shape = ctx.net.add_shape(embed).get_output(0)
157+
len_embed = ctx.net.add_gather(embed_shape, constant_0, 0).get_output(0)
153158

154159
if include_last_offset:
155160
# modify the last index of offsets to the end index
@@ -166,32 +171,38 @@ def embedding_bag_with_ITensor_offsets(
166171
# create a placeholder tensor, whose shape is the same as an embedding
167172
# if mode is 0 (sum) or 1 (mean), the placeholder tensor is filled with zeros
168173
# if mode is 2 (max), the placeholder tensor is filled with negative infinity
169-
placeholder_tensor = (
170-
get_trt_tensor(
174+
placeholder_tensor = impl.elementwise.mul(
175+
ctx, target, source_ir, f"{name}_zero_tensors", embed, 0
176+
)
177+
zero_tensor = ctx.net.add_gather(placeholder_tensor, constant_0, 0).get_output(0)
178+
zero_tensor = impl.shuffle.reshape(
179+
ctx, target, source_ir, f"{name}_reshape_zero_tensor", zero_tensor, (-1,)
180+
)
181+
182+
if mode == 2:
183+
placeholder_tensor = impl.elementwise.add(
171184
ctx,
172-
np.full(embed.shape, -np.inf, dtype=np.float32),
185+
target,
186+
source_ir,
173187
f"{name}_negative_inf_tensor",
188+
placeholder_tensor,
189+
-np.inf,
174190
)
175-
if mode == 2
176-
else get_trt_tensor(
177-
ctx, np.zeros(embed.shape, dtype=np.float32), f"{name}_zero_tensors"
178-
)
179-
)
180-
181-
# prepare some tensors for future use
182-
zero_tensor = get_trt_tensor(
183-
ctx, np.zeros((embed.shape[1],), dtype=np.float32), f"{name}_zero_tensor"
184-
)
185-
constant_0 = get_trt_tensor(ctx, 0, f"{name}_constant_tensor_0")
186-
constant_1 = get_trt_tensor(ctx, 1, f"{name}_constant_tensor_1")
187191

188192
# Use two for loops to calculate the embedding of each bag
189193
###### Outer loop: traverse offsets ######
190194
loop1 = ctx.net.add_loop()
191-
trip_limit1 = ctx.net.add_constant(
192-
shape=(),
193-
weights=trt.Weights(np.array([offsets.shape[0] - 1], dtype=np.int32)),
194-
).get_output(0)
195+
196+
offsets_shape = ctx.net.add_shape(offsets).get_output(0)
197+
offsets_size = ctx.net.add_gather(offsets_shape, constant_0, 0).get_output(0)
198+
trip_limit1 = impl.elementwise.sub(
199+
ctx, target, source_ir, f"{name}_trip_limit1", offsets_size, 1
200+
)
201+
# change to 0d tensor
202+
trip_limit1 = impl.shuffle.reshape(
203+
ctx, target, source_ir, f"{name}_reshape_trip_limit1", trip_limit1, ()
204+
)
205+
195206
loop1.add_trip_limit(trip_limit1, trt.TripLimit.COUNT)
196207

197208
rec1_i_tensor = loop1.add_recurrence(constant_1)
@@ -210,9 +221,9 @@ def embedding_bag_with_ITensor_offsets(
210221

211222
###### Inner loop: traverse indices ######
212223
loop2 = ctx.net.add_loop()
213-
trip_limit2 = ctx.net.add_constant(
214-
shape=(), weights=trt.Weights(np.array([len_embed], dtype=np.int32))
215-
).get_output(0)
224+
trip_limit2 = impl.shuffle.reshape(
225+
ctx, target, source_ir, f"{name}_reshape_trip_limit2", len_embed, ()
226+
)
216227
loop2.add_trip_limit(trip_limit2, trt.TripLimit.COUNT)
217228
rec2_j_tensor = loop2.add_recurrence(constant_0)
218229
set_layer_name(rec2_j_tensor, target, f"{name}_rec2_j_tensor", source_ir)

tests/py/dynamo/conversion/test_embedding_bag_aten.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -440,16 +440,20 @@ def forward(self, weight, indices, offsets):
440440
weights=torch.randn((5, 2), dtype=torch.float32),
441441
# weights_1 is for inference
442442
weights_1=torch.randn((6, 3), dtype=torch.float32),
443+
indices=torch.tensor([1, 2, 4, 2, 3, 4], dtype=torch.int32),
444+
offsets=torch.tensor([0, 2, 4], dtype=torch.int32),
443445
dynamic_shapes={
444446
"weights": {
445447
0: torch.export.Dim("dyn_dim", min=2, max=8),
446448
1: torch.export.Dim("dyn_dim_1", min=1, max=3),
447449
},
448-
"indices": {},
449-
"offsets": {},
450+
"indices": {
451+
0: torch.export.Dim("dyn_dim_in", min=2, max=32),
452+
},
453+
"offsets": {
454+
0: torch.export.Dim("dyn_dim_off", min=2, max=32),
455+
},
450456
},
451-
indices=torch.tensor([1, 2, 4, 2, 3, 4], dtype=torch.int32),
452-
offsets=torch.tensor([0, 2, 4], dtype=torch.int32),
453457
mode=1,
454458
per_sample_weights=None,
455459
),

0 commit comments

Comments
 (0)