Skip to content

Commit a016bc0

Browse files
committed
Optimized index converter
1 parent 2140c49 commit a016bc0

File tree

1 file changed

+11
-9
lines changed
  • py/torch_tensorrt/dynamo/conversion/impl

1 file changed

+11
-9
lines changed

py/torch_tensorrt/dynamo/conversion/impl/select.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -257,15 +257,17 @@ def index(
257257
)
258258
else:
259259
dim_tensor_shape_mult_d1 = transpose_tensor_shape[i]
260-
mult_d1 = convert_binary_elementwise(
261-
ctx,
262-
target,
263-
source_ir,
264-
name + f"_shape_{i}",
265-
trt.ElementWiseOperation.PROD,
266-
mult_d1,
267-
dim_tensor_shape_mult_d1,
268-
)
260+
261+
if isinstance(dim_tensor_shape_mult_d1, TRTTensor):
262+
mult_d1 = convert_binary_elementwise(
263+
ctx,
264+
target,
265+
source_ir,
266+
name + f"_shape_{i}",
267+
trt.ElementWiseOperation.PROD,
268+
mult_d1,
269+
dim_tensor_shape_mult_d1,
270+
)
269271

270272
concat_tensor_layer = ctx.net.add_concatenation(
271273
[

0 commit comments

Comments
 (0)