@@ -80,9 +80,29 @@ Weights::Weights(ConversionCtx* ctx, at::Tensor t) {
80
80
81
81
// Store the data in the conversion context so it remains until building is
82
82
// complete
83
- void * buf = malloc (t_cpu.numel () * sizeof (float ));
83
+
84
+ void * buf = nullptr ;
85
+
86
+ if (dtype_optional.value () == nvinfer1::DataType::kFLOAT ) {
87
+ buf = malloc (t_cpu.numel () * sizeof (float ));
88
+ memcpy (buf, t_cpu.data_ptr (), t_cpu.numel () * sizeof (float ));
89
+ } else if (dtype_optional.value () == nvinfer1::DataType::kHALF ) {
90
+ buf = malloc (t_cpu.numel () * (sizeof (float ) / 2 ));
91
+ memcpy (buf, t_cpu.data_ptr (), t_cpu.numel () * (sizeof (float ) / 2 ));
92
+ } else if (dtype_optional.value () == nvinfer1::DataType::kINT8 ) {
93
+ buf = malloc (t_cpu.numel () * sizeof (char ));
94
+ memcpy (buf, t_cpu.data_ptr (), t_cpu.numel () * sizeof (char ));
95
+ } else if (dtype_optional.value () == nvinfer1::DataType::kINT32 ) {
96
+ buf = malloc (t_cpu.numel () * sizeof (int ));
97
+ memcpy (buf, t_cpu.data_ptr (), t_cpu.numel () * sizeof (int ));
98
+ } else if (dtype_optional.value () == nvinfer1::DataType::kBOOL ) {
99
+ buf = malloc (t_cpu.numel () * sizeof (bool ));
100
+ memcpy (buf, t_cpu.data_ptr (), t_cpu.numel () * sizeof (bool ));
101
+ } else {
102
+ TRTORCH_THROW_ERROR (" Found unsupported data type for tensor to weight conversion" );
103
+ }
104
+
84
105
ctx->builder_resources .push_back (buf);
85
- memcpy (buf, t_cpu.data_ptr (), t_cpu.numel () * sizeof (float ));
86
106
87
107
this ->data .type = dtype_optional.value ();
88
108
this ->data .count = t_cpu.numel ();
0 commit comments