Skip to content

Commit 651f4c8

Browse files
committed
Optimized index converter
1 parent 0aa4e55 commit 651f4c8

File tree

1 file changed

+11
-9
lines changed
  • py/torch_tensorrt/dynamo/conversion/impl

1 file changed

+11
-9
lines changed

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -188,15 +188,17 @@ def index(
188188
)
189189
else:
190190
dim_tensor_shape_mult_d1 = transpose_tensor_shape[i]
191-
mult_d1 = convert_binary_elementwise(
192-
ctx,
193-
target,
194-
source_ir,
195-
name + f"_shape_{i}",
196-
trt.ElementWiseOperation.PROD,
197-
mult_d1,
198-
dim_tensor_shape_mult_d1,
199-
)
191+
192+
if isinstance(dim_tensor_shape_mult_d1, TRTTensor):
193+
mult_d1 = convert_binary_elementwise(
194+
ctx,
195+
target,
196+
source_ir,
197+
name + f"_shape_{i}",
198+
trt.ElementWiseOperation.PROD,
199+
mult_d1,
200+
dim_tensor_shape_mult_d1,
201+
)
200202

201203
concat_tensor_layer = ctx.net.add_concatenation(
202204
[

0 commit comments

Comments
 (0)