@@ -129,24 +129,24 @@ nvinfer1::ITensor* applyIdentityOp(ConversionCtx* ctx, nvinfer1::ITensor* tensor
129
129
}
130
130
131
131
nvinfer1::ITensor* castITensor (ConversionCtx* ctx, nvinfer1::ITensor* tensor, nvinfer1::DataType dtype) {
132
- // No matter whether tensor->getType() == dtype, identity layer is always needed.
133
- // Otherwise will change the input tensor name in aten::to converter by AssociateValueAndTensor function
134
- // When the input of aten::to is network input, will cause error
135
- std::ostringstream tensor_id;
136
- tensor_id << reinterpret_cast <int *>(tensor);
137
-
138
- auto id_layer = ctx->net ->addIdentity (*tensor);
139
- TORCHTRT_CHECK (id_layer, " Unable to create identity layer for ITensor: " << tensor_id.str ());
140
- auto casted_tensor = id_layer->getOutput (0 );
141
- casted_tensor->setType (dtype);
132
+ if (tensor->getType () != dtype) {
133
+ std::ostringstream tensor_id;
134
+ tensor_id << reinterpret_cast <int *>(tensor);
142
135
143
- LOG_DEBUG (ctx->logger , " Casting ITensor " << tensor_id.str () << " from " << tensor->getType () << " to " << dtype);
136
+ auto id_layer = ctx->net ->addIdentity (*tensor);
137
+ TORCHTRT_CHECK (id_layer, " Unable to create identity layer for ITensor: " << tensor_id.str ());
138
+ auto casted_tensor = id_layer->getOutput (0 );
139
+ casted_tensor->setType (dtype);
144
140
145
- std::stringstream ss;
146
- ss << " [Cast ITensor " << tensor_id.str () << " from " << tensor->getType () << " to " << dtype << " ]" ;
147
- id_layer->setName (ss.str ().c_str ());
148
- return casted_tensor;
141
+ LOG_DEBUG (ctx->logger , " Casting ITensor " << tensor_id.str () << " from " << tensor->getType () << " to " << dtype);
149
142
143
+ std::stringstream ss;
144
+ ss << " [Cast ITensor " << tensor_id.str () << " from " << tensor->getType () << " to " << dtype << " ]" ;
145
+ id_layer->setName (ss.str ().c_str ());
146
+ return casted_tensor;
147
+ } else {
148
+ return tensor;
149
+ }
150
150
}
151
151
152
152
nvinfer1::ITensor* tensor_to_const (ConversionCtx* ctx, at::Tensor t, const std::string& name) {
0 commit comments