@@ -89,7 +89,7 @@ nvinfer1::ITensor* Var::ITensorOrFreeze(ConversionCtx* ctx) {
89
89
if (isIValue ()) {
90
90
LOG_DEBUG (ctx->logger , " Found IValue containing object of type " << *(ptr_.ivalue ->type ()));
91
91
}
92
-
92
+
93
93
TRTORCH_CHECK (
94
94
isITensor () || (isIValue () && (ptr_.ivalue ->isTensor () || ptr_.ivalue ->isCustomClass ())),
95
95
" Requested either IValue containing a Tensor, or ITensor, however Var type is " << type_name ());
@@ -100,8 +100,10 @@ nvinfer1::ITensor* Var::ITensorOrFreeze(ConversionCtx* ctx) {
100
100
if (ptr_.ivalue ->isTensor ()) {
101
101
auto weights = converters::Weights ();
102
102
auto tensor = ptr_.ivalue ->toTensor ();
103
- if ((tensor.scalar_type () == at::kLong || tensor.scalar_type () == at::kDouble ) && !ctx->settings .truncate_long_and_double ) {
104
- TRTORCH_THROW_ERROR (" Unable to freeze tensor of type Int64/Float64 into constant layer, try to compile model with truncate_long_and_double enabled" );
103
+ if ((tensor.scalar_type () == at::kLong || tensor.scalar_type () == at::kDouble ) &&
104
+ !ctx->settings .truncate_long_and_double ) {
105
+ TRTORCH_THROW_ERROR (
106
+ " Unable to freeze tensor of type Int64/Float64 into constant layer, try to compile model with truncate_long_and_double enabled" );
105
107
} else if (tensor.scalar_type () == at::kLong && ctx->settings .truncate_long_and_double ) {
106
108
weights = converters::Weights (ctx, tensor.toType (at::kInt ));
107
109
LOG_WARNING (" Truncating weight (constant in the graph) from Int64 to Int32" );
@@ -111,7 +113,7 @@ nvinfer1::ITensor* Var::ITensorOrFreeze(ConversionCtx* ctx) {
111
113
} else {
112
114
weights = converters::Weights (ctx, tensor);
113
115
}
114
-
116
+
115
117
auto const_layer = ctx->net ->addConstant (weights.shape , weights.data );
116
118
TRTORCH_CHECK (const_layer, " Unable to freeze tensor into constant layer" );
117
119
out = const_layer->getOutput (0 );
0 commit comments