
Commit b0f827d

[chore](cuda): explicitly use ele_per_blk var for better readability (#1784)
1 parent 779bf14 commit b0f827d
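
For context, the pattern this commit applies, shown as a minimal sketch rather than the repository's exact code (the diff truncates the real signatures, so the blk_size parameter and the num_blocks derivation below are assumptions): each dequantize_* entry point already receives the number of elements per quantized block, and the output tensor is now shaped with that parameter instead of a hard-coded 32 or 256.

#include <torch/torch.h>  // libtorch C++ API (torch/extension.h in the extension build)

// Minimal sketch of the pattern; blk_size and the num_blocks computation are
// assumptions, since the diff only shows truncated signatures.
torch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes,
                              const int blk_size, const int ele_per_blk,
                              const torch::Device device, const torch::Dtype target_dtype) {
    const int num_blocks = num_bytes / blk_size;  // assumed derivation
    // Before: {num_blocks, 32}; after: the same value, named for what it means.
    auto output = torch::zeros({num_blocks, ele_per_blk},
                               torch::dtype(target_dtype).device(device));
    // ... copy `data` to the device and launch the per-format kernel (elided) ...
    return output;
}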

File tree (2 files changed, +14 -14 lines):

kt-kernel/cuda/custom_gguf/dequant.cu
kt-sft/csrc/ktransformers_ext/cuda/custom_gguf/dequant.cu

kt-kernel/cuda/custom_gguf/dequant.cu

Lines changed: 7 additions & 7 deletions
@@ -671,7 +671,7 @@ torch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int
     //data_gpu.copy_(data, false);
 
     // Create output tensor
-    auto output = torch::zeros({ num_blocks, 32 }, torch::dtype(target_dtype).device(device));
+    auto output = torch::zeros({ num_blocks, ele_per_blk }, torch::dtype(target_dtype).device(device));
 
     switch (target_dtype) {
         case torch::kFloat16:
@@ -705,7 +705,7 @@ torch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int
     //data_gpu.copy_(data, false);
 
     // Create output tensor
-    auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device));
+    auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device));
 
     switch (target_dtype) {
         case torch::kFloat16:
@@ -736,7 +736,7 @@ torch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int
     //data_gpu.copy_(data, false);
 
     // Create output tensor
-    auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device));
+    auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device));
 
     switch (target_dtype) {
         case torch::kFloat16:
@@ -768,7 +768,7 @@ torch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int
     //data_gpu.copy_(data, false);
 
     // Create output tensor
-    auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device));
+    auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device));
 
     switch (target_dtype) {
         case torch::kFloat16:
@@ -799,7 +799,7 @@ torch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int
     //data_gpu.copy_(data, false);
 
     // Create output tensor
-    auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device));
+    auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device));
 
     switch (target_dtype) {
         case torch::kFloat16:
@@ -830,7 +830,7 @@ torch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int
     //data_gpu.copy_(data, false);
 
     // Create output tensor
-    auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device));
+    auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device));
 
     switch (target_dtype) {
         case torch::kFloat16:
@@ -861,7 +861,7 @@ torch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const i
     //data_gpu.copy_(data, false);
 
     // Create output tensor
-    auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device));
+    auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device));
 
     switch (target_dtype) {
         case torch::kFloat16:

kt-sft/csrc/ktransformers_ext/cuda/custom_gguf/dequant.cu

Lines changed: 7 additions & 7 deletions
@@ -671,7 +671,7 @@ torch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int
     //data_gpu.copy_(data, false);
 
     // Create output tensor
-    auto output = torch::zeros({ num_blocks, 32 }, torch::dtype(target_dtype).device(device));
+    auto output = torch::zeros({ num_blocks, ele_per_blk }, torch::dtype(target_dtype).device(device));
 
     switch (target_dtype) {
         case torch::kFloat16:
@@ -705,7 +705,7 @@ torch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int
     //data_gpu.copy_(data, false);
 
     // Create output tensor
-    auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device));
+    auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device));
 
     switch (target_dtype) {
         case torch::kFloat16:
@@ -736,7 +736,7 @@ torch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int
     //data_gpu.copy_(data, false);
 
     // Create output tensor
-    auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device));
+    auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device));
 
     switch (target_dtype) {
         case torch::kFloat16:
@@ -768,7 +768,7 @@ torch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int
     //data_gpu.copy_(data, false);
 
     // Create output tensor
-    auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device));
+    auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device));
 
     switch (target_dtype) {
         case torch::kFloat16:
@@ -799,7 +799,7 @@ torch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int
     //data_gpu.copy_(data, false);
 
     // Create output tensor
-    auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device));
+    auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device));
 
     switch (target_dtype) {
         case torch::kFloat16:
@@ -830,7 +830,7 @@ torch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int
     //data_gpu.copy_(data, false);
 
     // Create output tensor
-    auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device));
+    auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device));
 
     switch (target_dtype) {
         case torch::kFloat16:
@@ -861,7 +861,7 @@ torch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const i
     //data_gpu.copy_(data, false);
 
     // Create output tensor
-    auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device));
+    auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device));
 
     switch (target_dtype) {
         case torch::kFloat16:
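
For reference, the literals being replaced match the block sizes the GGUF/ggml quant formats use; the constants below are not part of this diff, only a note on the values ele_per_blk is expected to carry.

// Not part of this commit: ggml block-size constants behind the replaced literals.
constexpr int QK8_0 = 32;   // elements per Q8_0 block (the 32 in dequantize_q8_0)
constexpr int QK_K  = 256;  // elements per K-quant / IQ4_XS super-block (the 256 elsewhere)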

0 commit comments
