Skip to content

Commit 36656b6

Browse files
authored
Merge pull request #847 from chaoz-dev/chaoz/fix-compare-intwrapper
[torch2trt/torch2trt.py] Resolve issue #846: Add support for using `._trt` when encountering an IntWrapper scalar.
2 parents b1a0360 + 80ab8b8 commit 36656b6

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

torch2trt/torch2trt.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -153,12 +153,12 @@ def add_missing_trt_tensors(network, tensors):
153153

154154
# get tensor w/ _trt
155155
# or... add constant for scalar primitive
156-
if isinstance(t, float) or isinstance(t, int):
156+
if hasattr(t, "_trt") or isinstance(t, IntWrapper):
157+
trt_tensor = t._trt
158+
elif isinstance(t, float) or isinstance(t, int):
157159
shape = (1,)
158160
scalar = t * torch.ones(shape, dtype=dtype).cpu().numpy()
159161
trt_tensor = network.add_constant(shape, scalar).get_output(0)
160-
elif hasattr(t, "_trt"):
161-
trt_tensor = t._trt
162162

163163
# or... add constant for leaf tensor w/o _trt
164164
else:
@@ -232,7 +232,7 @@ def trt_(network, *tensors):
232232
# GET TRT TENSOR (OR CREATE TRT CONSTANT)
233233

234234
# get tensor w/ _trt
235-
if isinstance(t, torch.Tensor) and hasattr(t, "_trt"):
235+
if (isinstance(t, torch.Tensor) and hasattr(t, "_trt")) or isinstance(t, IntWrapper):
236236
trt_tensor = t._trt
237237

238238
# or... add constant for leaf tensor w/o _trt

0 commit comments

Comments
 (0)