File tree Expand file tree Collapse file tree 3 files changed +15
-2
lines changed Expand file tree Collapse file tree 3 files changed +15
-2
lines changed Original file line number Diff line number Diff line change @@ -188,7 +188,7 @@ void AddInputs(
188
188
ctx->input_is_dynamic = true ;
189
189
}
190
190
191
- ctx->value_tensor_map [in] = trt_in;
191
+ ctx->AddNamedTensor (in, trt_in) ;
192
192
ctx->num_inputs += 1 ;
193
193
}
194
194
Original file line number Diff line number Diff line change @@ -130,8 +130,11 @@ ConversionCtx::~ConversionCtx() {
130
130
}
131
131
132
132
nvinfer1::ITensor* ConversionCtx::AssociateValueAndTensor (const torch::jit::Value* value, nvinfer1::ITensor* tensor) {
133
+ if (!AddNamedTensor (value, tensor)) {
134
+ LOG_WARNING (
135
+ " Trying to rewrite the name " << value->debugName () << " to a named ITensor " << tensor->getName () << " ." );
136
+ }
133
137
tensor->setName (value->debugName ().c_str ());
134
- this ->value_tensor_map [value] = tensor;
135
138
return tensor;
136
139
}
137
140
@@ -140,6 +143,12 @@ torch::jit::IValue* ConversionCtx::AssociateValueAndIValue(const torch::jit::Val
140
143
return &this ->evaluated_value_map [value];
141
144
}
142
145
146
+ bool ConversionCtx::AddNamedTensor (const torch::jit::Value* value, nvinfer1::ITensor* tensor) {
147
+ value_tensor_map[value] = tensor;
148
+ auto ret = named_tensors.insert (tensor);
149
+ return ret.second ;
150
+ }
151
+
143
152
std::string ConversionCtx::SerializeEngine () {
144
153
#if NV_TENSORRT_MAJOR > 7
145
154
auto serialized_network = builder->buildSerializedNetwork (*net, *cfg);
Original file line number Diff line number Diff line change @@ -46,6 +46,7 @@ struct ConversionCtx {
46
46
ConversionCtx (BuilderSettings settings);
47
47
std::string SerializeEngine ();
48
48
nvinfer1::ITensor* AssociateValueAndTensor (const torch::jit::Value* value, nvinfer1::ITensor* tensor);
49
+ bool AddNamedTensor (const torch::jit::Value* value, nvinfer1::ITensor* tensor);
49
50
torch::jit::IValue* AssociateValueAndIValue (const torch::jit::Value* value, torch::jit::IValue tensor);
50
51
bool CheckLayerAddition (const torch::jit::Node* n);
51
52
@@ -69,6 +70,9 @@ struct ConversionCtx {
69
70
70
71
std::unordered_map<const torch::jit::Value*, nvinfer1::ITensor*> value_tensor_map;
71
72
std::unordered_map<const torch::jit::Value*, torch::jit::IValue> evaluated_value_map;
73
+
74
+ // record already named ITensors to prevent rewriting another name to the same tensor
75
+ std::unordered_set<nvinfer1::ITensor*> named_tensors;
72
76
};
73
77
74
78
} // namespace conversion
You can’t perform that action at this time.
0 commit comments