Skip to content

Commit 721b071

Browse files
authored
Merge pull request #378 from NVIDIA/fix_multithreaded_fp16
fix(//core/conversion/converters/Weights): Fix buffer allocation for weights data
2 parents 9bd814d + 8dc3140 commit 721b071

File tree

1 file changed

+22
-2
lines changed

1 file changed

+22
-2
lines changed

core/conversion/converters/Weights.cpp

Lines changed: 22 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -80,9 +80,29 @@ Weights::Weights(ConversionCtx* ctx, at::Tensor t) {
80 80

81 81
// Store the data in the conversion context so it remains until building is
82 82
// complete
83-
void* buf = malloc(t_cpu.numel() * sizeof(float));
83+
84+
void* buf = nullptr;
85+
86+
if (dtype_optional.value() == nvinfer1::DataType::kFLOAT) {
87+
buf = malloc(t_cpu.numel() * sizeof(float));
88+
memcpy(buf, t_cpu.data_ptr(), t_cpu.numel() * sizeof(float));
89+
} else if (dtype_optional.value() == nvinfer1::DataType::kHALF) {
90+
buf = malloc(t_cpu.numel() * (sizeof(float) / 2));
91+
memcpy(buf, t_cpu.data_ptr(), t_cpu.numel() * (sizeof(float) / 2));
92+
} else if (dtype_optional.value() == nvinfer1::DataType::kINT8) {
93+
buf = malloc(t_cpu.numel() * sizeof(char));
94+
memcpy(buf, t_cpu.data_ptr(), t_cpu.numel() * sizeof(char));
95+
} else if (dtype_optional.value() == nvinfer1::DataType::kINT32) {
96+
buf = malloc(t_cpu.numel() * sizeof(int));
97+
memcpy(buf, t_cpu.data_ptr(), t_cpu.numel() * sizeof(int));
98+
} else if (dtype_optional.value() == nvinfer1::DataType::kBOOL) {
99+
buf = malloc(t_cpu.numel() * sizeof(bool));
100+
memcpy(buf, t_cpu.data_ptr(), t_cpu.numel() * sizeof(bool));
101+
} else {
102+
TRTORCH_THROW_ERROR("Found unsupported data type for tensor to weight conversion");
103+
}
104+
84 105
ctx->builder_resources.push_back(buf);
85-
memcpy(buf, t_cpu.data_ptr(), t_cpu.numel() * sizeof(float));
86 106

87 107
this->data.type = dtype_optional.value();
88 108
this->data.count = t_cpu.numel();

0 commit comments

Comments
 (0)