File tree Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Expand file tree Collapse file tree 1 file changed +4
-4
lines changed Original file line number Diff line number Diff line change @@ -153,12 +153,12 @@ def add_missing_trt_tensors(network, tensors):
153153
154154 # get tensor w/ _trt
155155 # or... add constant for scalar primitive
156- if isinstance (t , float ) or isinstance (t , int ):
156+ if hasattr (t , "_trt" ) or isinstance (t , IntWrapper ):
157+ trt_tensor = t ._trt
158+ elif isinstance (t , float ) or isinstance (t , int ):
157159 shape = (1 ,)
158160 scalar = t * torch .ones (shape , dtype = dtype ).cpu ().numpy ()
159161 trt_tensor = network .add_constant (shape , scalar ).get_output (0 )
160- elif hasattr (t , "_trt" ):
161- trt_tensor = t ._trt
162162
163163 # or... add constant for leaf tensor w/o _trt
164164 else :
@@ -232,7 +232,7 @@ def trt_(network, *tensors):
232232 # GET TRT TENSOR (OR CREATE TRT CONSTANT)
233233
234234 # get tensor w/ _trt
235- if isinstance (t , torch .Tensor ) and hasattr (t , "_trt" ):
235+ if ( isinstance (t , torch .Tensor ) and hasattr (t , "_trt" )) or isinstance ( t , IntWrapper ):
236236 trt_tensor = t ._trt
237237
238238 # or... add constant for leaf tensor w/o _trt
You can’t perform that action at this time.
0 commit comments