Skip to content

Commit dd7cfaf

Browse files
committed
fix(//core/conversion/converters/Weights): Fix buffer allocation for
weights data that occassionally may cause segfaults and causes issues with importing FP16 weights Signed-off-by: Naren Dasan <[email protected]> Signed-off-by: Naren Dasan <[email protected]>
1 parent 9bd814d commit dd7cfaf

File tree

1 file changed

+22
-2
lines changed

1 file changed

+22
-2
lines changed

core/conversion/converters/Weights.cpp

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,9 +80,29 @@ Weights::Weights(ConversionCtx* ctx, at::Tensor t) {
8080

8181
// Store the data in the conversion context so it remains until building is
8282
// complete
83-
void* buf = malloc(t_cpu.numel() * sizeof(float));
83+
84+
void* buf;
85+
86+
if (dtype_optional.value() == nvinfer1::DataType::kFLOAT) {
87+
buf = malloc(t_cpu.numel() * sizeof(float));
88+
memcpy(buf, t_cpu.data_ptr(), t_cpu.numel() * sizeof(float));
89+
} else if (dtype_optional.value() == nvinfer1::DataType::kHALF) {
90+
buf = malloc(t_cpu.numel() * (sizeof(float) / 2));
91+
memcpy(buf, t_cpu.data_ptr(), t_cpu.numel() * (sizeof(float) / 2));
92+
} else if (dtype_optional.value() == nvinfer1::DataType::kINT8) {
93+
buf = malloc(t_cpu.numel() * sizeof(char));
94+
memcpy(buf, t_cpu.data_ptr(), t_cpu.numel() * sizeof(char));
95+
} else if (dtype_optional.value() == nvinfer1::DataType::kINT32) {
96+
buf = malloc(t_cpu.numel() * sizeof(int));
97+
memcpy(buf, t_cpu.data_ptr(), t_cpu.numel() * sizeof(int));
98+
} else if (dtype_optional.value() == nvinfer1::DataType::kBOOL) {
99+
buf = malloc(t_cpu.numel() * sizeof(bool));
100+
memcpy(buf, t_cpu.data_ptr(), t_cpu.numel() * sizeof(bool));
101+
} else {
102+
TRTORCH_THROW_ERROR("Found unsupported data type for tensor to weight conversion");
103+
}
104+
84105
ctx->builder_resources.push_back(buf);
85-
memcpy(buf, t_cpu.data_ptr(), t_cpu.numel() * sizeof(float));
86106

87107
this->data.type = dtype_optional.value();
88108
this->data.count = t_cpu.numel();

0 commit comments

Comments
 (0)