Skip to content

Commit cb5c06b

Browse files
committed
refactor: refactor the RecordNewTensor function
Signed-off-by: Bo Wang <[email protected]>
1 parent c169a81 commit cb5c06b

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

core/conversion/conversionctx/ConversionCtx.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -130,11 +130,8 @@ ConversionCtx::~ConversionCtx() {
130130
}
131131

132132
nvinfer1::ITensor* ConversionCtx::AssociateValueAndTensor(const torch::jit::Value* value, nvinfer1::ITensor* tensor) {
133-
if (!RecordNewTensor(value, tensor)) {
134-
LOG_WARNING(
135-
"Trying to rewrite the name " << value->debugName() << " to a named ITensor " << tensor->getName() << ".");
136-
}
137-
tensor->setName(value->debugName().c_str());
133+
RecordNewTensor(value, tensor);
134+
138135
return tensor;
139136
}
140137

@@ -143,10 +140,13 @@ torch::jit::IValue* ConversionCtx::AssociateValueAndIValue(const torch::jit::Val
143140
return &this->evaluated_value_map[value];
144141
}
145142

146-
bool ConversionCtx::RecordNewTensor(const torch::jit::Value* value, nvinfer1::ITensor* tensor) {
143+
void ConversionCtx::RecordNewTensor(const torch::jit::Value* value, nvinfer1::ITensor* tensor) {
147144
value_tensor_map[value] = tensor;
148145
auto ret = known_tensors.insert(tensor);
149-
return ret.second;
146+
if (!ret) {
147+
LOG_WARNING(
148+
"Trying to record the value " << value->debugName() << " with the ITensor " << tensor->getName() << " again.");
149+
}
150150
}
151151

152152
std::string ConversionCtx::SerializeEngine() {

core/conversion/conversionctx/ConversionCtx.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ struct ConversionCtx {
4646
ConversionCtx(BuilderSettings settings);
4747
std::string SerializeEngine();
4848
nvinfer1::ITensor* AssociateValueAndTensor(const torch::jit::Value* value, nvinfer1::ITensor* tensor);
49-
bool RecordNewTensor(const torch::jit::Value* value, nvinfer1::ITensor* tensor);
49+
void RecordNewTensor(const torch::jit::Value* value, nvinfer1::ITensor* tensor);
5050
torch::jit::IValue* AssociateValueAndIValue(const torch::jit::Value* value, torch::jit::IValue tensor);
5151
bool CheckLayerAddition(const torch::jit::Node* n);
5252

0 commit comments

Comments
 (0)