Skip to content

Commit a901b0c

Browse files
committed
fix: add a warning msg when renaming itensors
Signed-off-by: Bo Wang <[email protected]>
1 parent 1a22204 commit a901b0c

File tree

3 files changed

+15
-2
lines changed

3 files changed

+15
-2
lines changed

core/conversion/conversion.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ void AddInputs(
188188
ctx->input_is_dynamic = true;
189189
}
190190

191-
ctx->value_tensor_map[in] = trt_in;
191+
ctx->AddNamedTensor(in, trt_in);
192192
ctx->num_inputs += 1;
193193
}
194194

core/conversion/conversionctx/ConversionCtx.cpp

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,11 @@ ConversionCtx::~ConversionCtx() {
130130
}
131131

132132
nvinfer1::ITensor* ConversionCtx::AssociateValueAndTensor(const torch::jit::Value* value, nvinfer1::ITensor* tensor) {
133+
if (!AddNamedTensor(value, tensor)) {
134+
LOG_WARNING(
135+
"Trying to rewrite the name " << value->debugName() << " to a named ITensor " << tensor->getName() << ".");
136+
}
133137
tensor->setName(value->debugName().c_str());
134-
this->value_tensor_map[value] = tensor;
135138
return tensor;
136139
}
137140

@@ -140,6 +143,12 @@ torch::jit::IValue* ConversionCtx::AssociateValueAndIValue(const torch::jit::Val
140143
return &this->evaluated_value_map[value];
141144
}
142145

146+
bool ConversionCtx::AddNamedTensor(const torch::jit::Value* value, nvinfer1::ITensor* tensor) {
147+
value_tensor_map[value] = tensor;
148+
auto ret = named_tensors.insert(tensor);
149+
return ret.second;
150+
}
151+
143152
std::string ConversionCtx::SerializeEngine() {
144153
#if NV_TENSORRT_MAJOR > 7
145154
auto serialized_network = builder->buildSerializedNetwork(*net, *cfg);

core/conversion/conversionctx/ConversionCtx.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ struct ConversionCtx {
4646
ConversionCtx(BuilderSettings settings);
4747
std::string SerializeEngine();
4848
nvinfer1::ITensor* AssociateValueAndTensor(const torch::jit::Value* value, nvinfer1::ITensor* tensor);
49+
bool AddNamedTensor(const torch::jit::Value* value, nvinfer1::ITensor* tensor);
4950
torch::jit::IValue* AssociateValueAndIValue(const torch::jit::Value* value, torch::jit::IValue tensor);
5051
bool CheckLayerAddition(const torch::jit::Node* n);
5152

@@ -69,6 +70,9 @@ struct ConversionCtx {
6970

7071
std::unordered_map<const torch::jit::Value*, nvinfer1::ITensor*> value_tensor_map;
7172
std::unordered_map<const torch::jit::Value*, torch::jit::IValue> evaluated_value_map;
73+
74+
// record already named ITensors to prevent rewriting another name to the same tensor
75+
std::unordered_set<nvinfer1::ITensor*> named_tensors;
7276
};
7377

7478
} // namespace conversion

0 commit comments

Comments
 (0)