@@ -671,7 +671,7 @@ torch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int
671671 // data_gpu.copy_(data, false);
672672
673673 // Create output tensor
674- auto output = torch::zeros ({ num_blocks, 32 }, torch::dtype (target_dtype).device (device));
674+ auto output = torch::zeros ({ num_blocks, ele_per_blk }, torch::dtype (target_dtype).device (device));
675675
676676 switch (target_dtype) {
677677 case torch::kFloat16 :
@@ -705,7 +705,7 @@ torch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int
705705 // data_gpu.copy_(data, false);
706706
707707 // Create output tensor
708- auto output = torch::zeros ({num_blocks, 256 }, torch::dtype (target_dtype).device (device));
708+ auto output = torch::zeros ({num_blocks, ele_per_blk }, torch::dtype (target_dtype).device (device));
709709
710710 switch (target_dtype) {
711711 case torch::kFloat16 :
@@ -736,7 +736,7 @@ torch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int
736736 // data_gpu.copy_(data, false);
737737
738738 // Create output tensor
739- auto output = torch::zeros ({num_blocks, 256 }, torch::dtype (target_dtype).device (device));
739+ auto output = torch::zeros ({num_blocks, ele_per_blk }, torch::dtype (target_dtype).device (device));
740740
741741 switch (target_dtype) {
742742 case torch::kFloat16 :
@@ -768,7 +768,7 @@ torch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int
768768 // data_gpu.copy_(data, false);
769769
770770 // Create output tensor
771- auto output = torch::zeros ({num_blocks, 256 }, torch::dtype (target_dtype).device (device));
771+ auto output = torch::zeros ({num_blocks, ele_per_blk }, torch::dtype (target_dtype).device (device));
772772
773773 switch (target_dtype) {
774774 case torch::kFloat16 :
@@ -799,7 +799,7 @@ torch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int
799799 // data_gpu.copy_(data, false);
800800
801801 // Create output tensor
802- auto output = torch::zeros ({num_blocks, 256 }, torch::dtype (target_dtype).device (device));
802+ auto output = torch::zeros ({num_blocks, ele_per_blk }, torch::dtype (target_dtype).device (device));
803803
804804 switch (target_dtype) {
805805 case torch::kFloat16 :
@@ -830,7 +830,7 @@ torch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int
830830 // data_gpu.copy_(data, false);
831831
832832 // Create output tensor
833- auto output = torch::zeros ({num_blocks, 256 }, torch::dtype (target_dtype).device (device));
833+ auto output = torch::zeros ({num_blocks, ele_per_blk }, torch::dtype (target_dtype).device (device));
834834
835835 switch (target_dtype) {
836836 case torch::kFloat16 :
@@ -861,7 +861,7 @@ torch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const i
861861 // data_gpu.copy_(data, false);
862862
863863 // Create output tensor
864- auto output = torch::zeros ({num_blocks, 256 }, torch::dtype (target_dtype).device (device));
864+ auto output = torch::zeros ({num_blocks, ele_per_blk }, torch::dtype (target_dtype).device (device));
865865
866866 switch (target_dtype) {
867867 case torch::kFloat16 :
0 commit comments