@@ -129,24 +129,24 @@ nvinfer1::ITensor* applyIdentityOp(ConversionCtx* ctx, nvinfer1::ITensor* tensor
129
129
}
130
130
131
131
nvinfer1::ITensor* castITensor (ConversionCtx* ctx, nvinfer1::ITensor* tensor, nvinfer1::DataType dtype) {
132
- if (tensor->getType () != dtype) {
133
- std::ostringstream tensor_id;
134
- tensor_id << reinterpret_cast <int *>(tensor);
132
+ // No matter whether tensor->getType() == dtype, identity layer is always needed.
133
+ // Otherwise will change the input tensor name in aten::to converter by AssociateValueAndTensor function
134
+ // When the input of aten::to is network input, will cause error
135
+ std::ostringstream tensor_id;
136
+ tensor_id << reinterpret_cast <int *>(tensor);
135
137
136
- auto id_layer = ctx->net ->addIdentity (*tensor);
137
- TORCHTRT_CHECK (id_layer, " Unable to create identity layer for ITensor: " << tensor_id.str ());
138
- auto casted_tensor = id_layer->getOutput (0 );
139
- casted_tensor->setType (dtype);
138
+ auto id_layer = ctx->net ->addIdentity (*tensor);
139
+ TORCHTRT_CHECK (id_layer, " Unable to create identity layer for ITensor: " << tensor_id.str ());
140
+ auto casted_tensor = id_layer->getOutput (0 );
141
+ casted_tensor->setType (dtype);
140
142
141
- LOG_DEBUG (ctx->logger , " Casting ITensor " << tensor_id.str () << " from " << tensor->getType () << " to " << dtype);
143
+ LOG_DEBUG (ctx->logger , " Casting ITensor " << tensor_id.str () << " from " << tensor->getType () << " to " << dtype);
144
+
145
+ std::stringstream ss;
146
+ ss << " [Cast ITensor " << tensor_id.str () << " from " << tensor->getType () << " to " << dtype << " ]" ;
147
+ id_layer->setName (ss.str ().c_str ());
148
+ return casted_tensor;
142
149
143
- std::stringstream ss;
144
- ss << " [Cast ITensor " << tensor_id.str () << " from " << tensor->getType () << " to " << dtype << " ]" ;
145
- id_layer->setName (ss.str ().c_str ());
146
- return casted_tensor;
147
- } else {
148
- return tensor;
149
- }
150
150
}
151
151
152
152
nvinfer1::ITensor* tensor_to_const (ConversionCtx* ctx, at::Tensor t, const std::string& name) {
0 commit comments