Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 34 additions & 23 deletions py/torch_tensorrt/dynamo/conversion/impl/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,12 @@ def embedding_bag_with_ITensor_offsets(
mode: int,
include_last_offset: bool,
) -> Tuple[TRTTensor, TRTTensor, TRTTensor, TRTTensor]:
len_embed = embed.shape[0]
# prepare some tensors for future use
constant_0 = get_trt_tensor(ctx, 0, f"{name}_constant_tensor_0")
constant_1 = get_trt_tensor(ctx, 1, f"{name}_constant_tensor_1")

embed_shape = ctx.net.add_shape(embed).get_output(0)
len_embed = ctx.net.add_gather(embed_shape, constant_0, 0).get_output(0)

if include_last_offset:
# modify the last index of offsets to the end index
Expand All @@ -166,32 +171,38 @@ def embedding_bag_with_ITensor_offsets(
# create a placeholder tensor, whose shape is the same as an embedding
# if mode is 0 (sum) or 1 (mean), the placeholder tensor is filled with zeros
# if mode is 2 (max), the placeholder tensor is filled with negative infinity
placeholder_tensor = (
get_trt_tensor(
placeholder_tensor = impl.elementwise.mul(
ctx, target, source_ir, f"{name}_zero_tensors", embed, 0
)
zero_tensor = ctx.net.add_gather(placeholder_tensor, constant_0, 0).get_output(0)
zero_tensor = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_reshape_zero_tensor", zero_tensor, (-1,)
)

if mode == 2:
placeholder_tensor = impl.elementwise.add(
ctx,
np.full(embed.shape, -np.inf, dtype=np.float32),
target,
source_ir,
f"{name}_negative_inf_tensor",
placeholder_tensor,
-np.inf,
)
if mode == 2
else get_trt_tensor(
ctx, np.zeros(embed.shape, dtype=np.float32), f"{name}_zero_tensors"
)
)

# prepare some tensors for future use
zero_tensor = get_trt_tensor(
ctx, np.zeros((embed.shape[1],), dtype=np.float32), f"{name}_zero_tensor"
)
constant_0 = get_trt_tensor(ctx, 0, f"{name}_constant_tensor_0")
constant_1 = get_trt_tensor(ctx, 1, f"{name}_constant_tensor_1")

# Use two for loops to calculate the embedding of each bag
###### Outer loop: traverse offsets ######
loop1 = ctx.net.add_loop()
trip_limit1 = ctx.net.add_constant(
shape=(),
weights=trt.Weights(np.array([offsets.shape[0] - 1], dtype=np.int32)),
).get_output(0)

offsets_shape = ctx.net.add_shape(offsets).get_output(0)
offsets_size = ctx.net.add_gather(offsets_shape, constant_0, 0).get_output(0)
trip_limit1 = impl.elementwise.sub(
ctx, target, source_ir, f"{name}_trip_limit1", offsets_size, 1
)
# change to 0d tensor
trip_limit1 = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_reshape_trip_limit1", trip_limit1, ()
)

loop1.add_trip_limit(trip_limit1, trt.TripLimit.COUNT)

rec1_i_tensor = loop1.add_recurrence(constant_1)
Expand All @@ -210,9 +221,9 @@ def embedding_bag_with_ITensor_offsets(

###### Inner loop: traverse indices ######
loop2 = ctx.net.add_loop()
trip_limit2 = ctx.net.add_constant(
shape=(), weights=trt.Weights(np.array([len_embed], dtype=np.int32))
).get_output(0)
trip_limit2 = impl.shuffle.reshape(
ctx, target, source_ir, f"{name}_reshape_trip_limit2", len_embed, ()
)
loop2.add_trip_limit(trip_limit2, trt.TripLimit.COUNT)
rec2_j_tensor = loop2.add_recurrence(constant_0)
set_layer_name(rec2_j_tensor, target, f"{name}_rec2_j_tensor", source_ir)
Expand Down
12 changes: 8 additions & 4 deletions tests/py/dynamo/conversion/test_embedding_bag_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,16 +438,20 @@ def forward(self, weight, indices, offsets):
weights=torch.randn((5, 2), dtype=torch.float32),
# weights_1 is for inference
weights_1=torch.randn((6, 3), dtype=torch.float32),
indices=torch.tensor([1, 2, 4, 2, 3, 4], dtype=torch.int32),
offsets=torch.tensor([0, 2, 4], dtype=torch.int32),
dynamic_shapes={
"weights": {
0: torch.export.Dim("dyn_dim", min=2, max=8),
1: torch.export.Dim("dyn_dim_1", min=1, max=3),
},
"indices": {},
"offsets": {},
"indices": {
0: torch.export.Dim("dyn_dim_in", min=2, max=32),
},
"offsets": {
0: torch.export.Dim("dyn_dim_off", min=2, max=32),
},
},
indices=torch.tensor([1, 2, 4, 2, 3, 4], dtype=torch.int32),
offsets=torch.tensor([0, 2, 4], dtype=torch.int32),
mode=1,
per_sample_weights=None,
),
Expand Down
Loading