Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions kt-kernel/cuda/custom_gguf/dequant.cu
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,7 @@ torch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int
//data_gpu.copy_(data, false);

// Create output tensor
auto output = torch::zeros({ num_blocks, 32 }, torch::dtype(target_dtype).device(device));
auto output = torch::zeros({ num_blocks, ele_per_blk }, torch::dtype(target_dtype).device(device));

switch (target_dtype) {
case torch::kFloat16:
Expand Down Expand Up @@ -705,7 +705,7 @@ torch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int
//data_gpu.copy_(data, false);

// Create output tensor
auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device));
auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device));

switch (target_dtype) {
case torch::kFloat16:
Expand Down Expand Up @@ -736,7 +736,7 @@ torch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int
//data_gpu.copy_(data, false);

// Create output tensor
auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device));
auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device));

switch (target_dtype) {
case torch::kFloat16:
Expand Down Expand Up @@ -768,7 +768,7 @@ torch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int
//data_gpu.copy_(data, false);

// Create output tensor
auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device));
auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device));

switch (target_dtype) {
case torch::kFloat16:
Expand Down Expand Up @@ -799,7 +799,7 @@ torch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int
//data_gpu.copy_(data, false);

// Create output tensor
auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device));
auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device));

switch (target_dtype) {
case torch::kFloat16:
Expand Down Expand Up @@ -830,7 +830,7 @@ torch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int
//data_gpu.copy_(data, false);

// Create output tensor
auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device));
auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device));

switch (target_dtype) {
case torch::kFloat16:
Expand Down Expand Up @@ -861,7 +861,7 @@ torch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const i
//data_gpu.copy_(data, false);

// Create output tensor
auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device));
auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device));

switch (target_dtype) {
case torch::kFloat16:
Expand Down
14 changes: 7 additions & 7 deletions kt-sft/csrc/ktransformers_ext/cuda/custom_gguf/dequant.cu
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,7 @@ torch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int
//data_gpu.copy_(data, false);

// Create output tensor
auto output = torch::zeros({ num_blocks, 32 }, torch::dtype(target_dtype).device(device));
auto output = torch::zeros({ num_blocks, ele_per_blk }, torch::dtype(target_dtype).device(device));

switch (target_dtype) {
case torch::kFloat16:
Expand Down Expand Up @@ -705,7 +705,7 @@ torch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int
//data_gpu.copy_(data, false);

// Create output tensor
auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device));
auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device));

switch (target_dtype) {
case torch::kFloat16:
Expand Down Expand Up @@ -736,7 +736,7 @@ torch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int
//data_gpu.copy_(data, false);

// Create output tensor
auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device));
auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device));

switch (target_dtype) {
case torch::kFloat16:
Expand Down Expand Up @@ -768,7 +768,7 @@ torch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int
//data_gpu.copy_(data, false);

// Create output tensor
auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device));
auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device));

switch (target_dtype) {
case torch::kFloat16:
Expand Down Expand Up @@ -799,7 +799,7 @@ torch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int
//data_gpu.copy_(data, false);

// Create output tensor
auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device));
auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device));

switch (target_dtype) {
case torch::kFloat16:
Expand Down Expand Up @@ -830,7 +830,7 @@ torch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int
//data_gpu.copy_(data, false);

// Create output tensor
auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device));
auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device));

switch (target_dtype) {
case torch::kFloat16:
Expand Down Expand Up @@ -861,7 +861,7 @@ torch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const i
//data_gpu.copy_(data, false);

// Create output tensor
auto output = torch::zeros({num_blocks, 256}, torch::dtype(target_dtype).device(device));
auto output = torch::zeros({num_blocks, ele_per_blk}, torch::dtype(target_dtype).device(device));

switch (target_dtype) {
case torch::kFloat16:
Expand Down