Commit e55e95e

Smarterquant tests fixes (#7)

* Checkpoint: Refactor SmarterQuantTensorInfo and add headers

  - Created a C-compatible `SmarterQuantTensorInfo` in `ggml-smarterquant-types.h`.
  - Updated `ggml.h`, `ggml-cpu.c`, `llama-quant.h`, `llama-quant.cpp`, `llama-model-loader.cpp`, and `llama-model.cpp` to use the new struct.
  - Added missing C++ headers and forward declarations to `llama-quant.cpp` in an attempt to resolve compilation errors.

  Note: the codebase did not compile at this checkpoint due to issues in `llama-quant.cpp` and an incorrect CMake build path used in the last attempt.

* Fix compilation issues and implement SmarterQuant stubs

  - Resolved various compilation errors in `llama-quant.cpp` related to includes, function definitions, and SmarterQuant logic.
  - Implemented parsing of the SmarterQuant JSON configuration in `load_smarter_quant_config`.
  - Added a basic serial implementation of `llama_tensor_quantize_smarter_blocks`.
  - Provided functional stubs for quantization helper functions within `llama-quant.cpp`.
  - Ensured the public `llama_model_quantize` API correctly calls the implementation in `llama-quant.cpp`.
  - Fixed a memory leak by adding a destructor to `llama_model` to free SmarterQuant permutation data.
  - Verified that the `ggml-cpu.c` and `llama-model.cpp` changes for SmarterQuant dequantization compile.
  - The main library and all example tools now compile and link successfully.

* feat: Implement SmarterQuant numerical correctness tests

  Introduces a test suite verifying the numerical correctness of the custom block quantization and dequantization logic.

  Key changes:
  - Added `tests/test-smarterquant.cpp` with a test case that:
    - uses a sample F32 tensor with mixed quantization types (Q4_0, Q5_1, Q8_0, Q2_K);
    - applies a column permutation;
    - quantizes with `llama_tensor_quantize_smarter_blocks`;
    - dequantizes with `ggml_get_rows_smarterquant`;
    - verifies the output against the original F32 data.
  - Updated `tests/CMakeLists.txt` to build the new test.
  - Made `llama_tensor_quantize_smarter_blocks` in `src/llama-quant.cpp` non-static and added its declaration to `src/llama-quant.h`.
  - Made `ggml_get_rows_smarterquant` in `ggml/src/ggml-cpu/ggml-cpu.c` non-static to allow direct testing.
  - The test passes, confirming the core CPU implementation of SmarterQuant (Tasks 1 and 2 from todo.txt) works as expected for the tested scenario.

* feat: Implement SmarterQuant numerical correctness tests and update todo

  Same test suite as above; additionally updates `todo.txt` to mark the CPU numerical correctness testing as DONE and to outline further potential test enhancements.

---------

Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
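The round trip exercised by `tests/test-smarterquant.cpp` (permute columns, quantize 256-wide segments, dequantize, unpermute) can be sketched standalone. This is an illustration only: `encode_segment`/`decode_segment` below are hypothetical identity stand-ins for the real lossy Q4_0/Q5_1/Q8_0/Q2_K codecs, so the round trip here is exact rather than approximate.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <numeric>
#include <vector>

// Hypothetical stand-ins for the real per-segment codecs. The actual test uses
// lossy ggml quantization types via llama_tensor_quantize_smarter_blocks; an
// identity codec keeps this sketch exactly verifiable.
static std::vector<float> encode_segment(const float * src, int64_t n) {
    return std::vector<float>(src, src + n);
}
static void decode_segment(const std::vector<float> & enc, float * dst, int64_t n) {
    std::copy(enc.begin(), enc.begin() + n, dst);
}

// Round-trip one row: permute columns, encode 256-wide segments, decode, unpermute.
std::vector<float> round_trip_row(const std::vector<float> & row, const std::vector<int32_t> & perm) {
    const int64_t ne0 = (int64_t) row.size();

    // permuted[j] = row[perm[j]]: quantization sees the permuted layout
    std::vector<float> permuted(ne0);
    for (int64_t j = 0; j < ne0; ++j) permuted[j] = row[perm[j]];

    // encode/decode each 256-element segment independently
    std::vector<float> decoded(ne0);
    for (int64_t j = 0; j < ne0; j += 256) {
        const int64_t n = std::min<int64_t>(256, ne0 - j);
        auto enc = encode_segment(permuted.data() + j, n);
        decode_segment(enc, decoded.data() + j, n);
    }

    // unpermute, mirroring ggml_get_rows_smarterquant: dst[perm[j]] = decoded[j]
    std::vector<float> out(ne0);
    for (int64_t j = 0; j < ne0; ++j) out[perm[j]] = decoded[j];
    return out;
}
```

With an identity codec the output must equal the input bit-for-bit; with the real codecs the test instead compares against a tolerance.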
1 parent 6de9577 commit e55e95e

File tree

14 files changed, +1115 -1006 lines changed

ggml-smarterquant-types.h

Lines changed: 29 additions & 0 deletions

@@ -0,0 +1,29 @@
+#pragma once
+
+#include <stdint.h>
+#include <stdbool.h>
+// Forward declare ggml_type if it's not pulled in by stdint/stdbool,
+// though it's better if this file can be self-contained for basic types
+// or include a minimal ggml_core_types.h if one existed.
+// For now, assuming ggml_type will be known by consumers including this after ggml.h,
+// or we might need to include "ggml_core.h" or similar if such a thing exists
+// that defines ggml_type without pulling all of ggml.h.
+// Given its usage in ggml_tensor, it should be fine.
+
+// C-compatible structure for SmarterQuant tensor information
+struct SmarterQuantTensorInfo {
+    // Specifies the ggml_type (as int8_t for storage, cast to enum ggml_type for use)
+    // for each of the first four 256-column-wide blocks of the tensor.
+    // Subsequent blocks will use the type specified at index 3.
+    int8_t compression_types[4];
+
+    // Defines how columns of the original tensor should be reordered.
+    // Points to an array of column indices.
+    // The element at new_data[col_idx_new] comes from original_data[column_permutation[col_idx_new]].
+    // This memory must be managed externally (e.g., by the code loading the configuration).
+    int32_t * column_permutation;     // int32_t, as column indices are usually within this range
+    int64_t   n_cols_for_permutation; // number of elements in column_permutation; should match the tensor's ne[0]
+
+    // Flag indicating if SmarterQuant is enabled for this tensor.
+    bool enabled;
+};
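A minimal sketch of how a consumer might interpret the two key fields of this struct. The struct is re-declared locally (with a `Sketch` suffix) so the example compiles on its own; the type-selection rule follows the comment on `compression_types`, and the type values used below are arbitrary placeholders, not real `ggml_type` enum constants.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>

// Local mirror of the header's struct, so this sketch is self-contained.
struct SmarterQuantTensorInfoSketch {
    int8_t    compression_types[4];
    int32_t * column_permutation;
    int64_t   n_cols_for_permutation;
    bool      enabled;
};

// Type-selection rule from the header comment: the first four 256-column
// blocks use their own entry; every later block reuses the entry at index 3.
static int8_t segment_type_for_block(const SmarterQuantTensorInfoSketch & info, int64_t block_idx) {
    return info.compression_types[std::min<int64_t>(block_idx, 3)];
}
```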

ggml/include/ggml.h

Lines changed: 8 additions & 1 deletion

@@ -347,6 +347,12 @@ extern "C" {
     struct ggml_context;
     struct ggml_cgraph;

+    // Forward declare SmarterQuantTensorInfo
+    // Actual definition is in llama-quant.h, which is not included here to keep ggml.h independent.
+    // ggml_tensor will store a void pointer to be cast to SmarterQuantTensorInfo * when needed.
+    // #include "llama-quant.h" // No longer needed here, definition moved or forward declared
+    #include "ggml-smarterquant-types.h" // Contains definition for SmarterQuantTensorInfo
+
     // NOTE: always add types at the end of the enum to keep backward compatibility
     enum ggml_type {
         GGML_TYPE_F32 = 0,
@@ -605,8 +611,9 @@ extern "C" {
         char name[GGML_MAX_NAME];

         void * extra; // extra things e.g. for ggml-cuda.cu
+        struct SmarterQuantTensorInfo * sq_info; // For SmarterQuant per-block quantization info

-        char padding[8];
+        char padding[16]; // Adjusted padding for alignment
     };

     static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 129 additions & 26 deletions

@@ -4,12 +4,19 @@
 #include "ggml-backend-impl.h"
 #include "ggml-backend.h"
 #include "ggml-cpu-traits.h"
-#include "ggml-cpu-impl.h"
-#include "ggml-cpu.h"
-#include "ggml-impl.h"
-#include "ggml-cpu-quants.h"
-#include "ggml-threading.h"
-#include "ggml.h"
+#include "ggml-backend-impl.h" // Keep this
+#include "ggml-backend.h"      // Keep this
+#include "ggml-cpu-traits.h"   // Keep this
+#include "ggml-cpu-impl.h"     // Keep this
+#include "ggml-cpu.h"          // Keep this
+#include "ggml-impl.h"         // Keep this
+#include "ggml-cpu-quants.h"   // Keep this
+#include "ggml-threading.h"    // Keep this
+#include "ggml.h"              // Keep this; it now also includes ggml-smarterquant-types.h
+// No longer need to include llama-quant.h or ../llama-quant.h here
+
+// Forward declaration for SmarterQuant dequantization function
+void ggml_get_rows_smarterquant(const struct ggml_tensor * tensor, const char * src_row_base, float * dst_row_final_target);

 #if defined(_MSC_VER) || defined(__MINGW32__)
 #include <malloc.h> // using malloc.h with MSC/MINGW
@@ -9726,6 +9733,9 @@ static void ggml_compute_forward_transpose(

 // ggml_compute_forward_get_rows

+// This is the older, likely problematic definition that will be removed by the other change.
+// The redefinition error indicates a duplicate. The one at line 13243 is the one we want to keep and fix.
+
 static void ggml_compute_forward_get_rows_q(
     const struct ggml_compute_params * params,
     struct ggml_tensor * dst) {
@@ -9742,7 +9752,9 @@ static void ggml_compute_forward_get_rows_q(
     ggml_to_float_t const dequantize_row_q = ggml_get_type_traits(type)->to_float;

     assert(ne0 == nc);
-    assert(ne02 == ne11);
+    assert(ne02 == ne11); // Batch/head dimensions must match or be broadcastable, if that's intended.
+    // For SmarterQuant, src0->type can be different from dst->type (which is F32).
+    // The original assertion nb00 == ggml_type_size(type) is fine as it refers to src0.
     assert(nb00 == ggml_type_size(type));
     assert(ggml_nrows(dst) == nr);

@@ -9757,16 +9769,23 @@
     const int ir1 = MIN(ir0 + dr, nr);

     for (int64_t i = ir0; i < ir1; ++i) {
-        const int64_t i12 = i/(ne11*ne10);
-        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
-        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
-        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+        const int64_t i12 = i/(ne11*ne10);                  // dst batch index
+        const int64_t i11 = (i - i12*ne11*ne10)/ne10;       // dst head index / dst row index within batch
+        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); // dst element index within row
+        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); // row index in src0

-        GGML_ASSERT(i01 >= 0 && i01 < ne01);
+        GGML_ASSERT(i01 >= 0 && i01 < ne01); // Ensure index is valid for src0's rows

-        dequantize_row_q(
-                (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
-                     (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
+        float * const dst_current_row_ptr = (float *)((char *)dst->data + i10*nb1 + i11*nb2 + i12*nb3);
+        const char * const src_row_ptr = (char *)src0->data + i01*nb01 + i11*nb02 + i12*nb03; // Assuming src0's higher dim strides match dst's for batch/head
+
+        if (src0->sq_info != NULL && src0->sq_info->enabled) {
+            // SmarterQuant path: dst is already F32, so ggml_get_rows_smarterquant works directly.
+            ggml_get_rows_smarterquant(src0, src_row_ptr, dst_current_row_ptr);
+        } else {
+            // Original quantized path
+            dequantize_row_q(src_row_ptr, dst_current_row_ptr, nc);
+        }
     }
 }

@@ -9865,8 +9884,13 @@ static void ggml_compute_forward_get_rows_f32(
     const int64_t nr = ggml_nelements(src1);

     assert(ne0 == nc);
-    assert(ne02 == ne11);
-    assert(nb00 == sizeof(float));
+    assert(ne02 == ne11); // Batch/head dimensions must match or be broadcastable.
+    // For SmarterQuant, src0->type might be a quantized type, but dst is F32.
+    // The original nb00 assertion is for the original F32 path;
+    // if SmarterQuant is active, src0->type is not necessarily F32.
+    if (!(src0->sq_info != NULL && src0->sq_info->enabled)) {
+        GGML_ASSERT(nb00 == sizeof(float)); // Original assertion for non-SmarterQuant F32 path
+    }
     assert(ggml_nrows(dst) == nr);

     const int ith = params->ith;
@@ -9880,16 +9904,23 @@
     const int ir1 = MIN(ir0 + dr, nr);

     for (int64_t i = ir0; i < ir1; ++i) {
-        const int64_t i12 = i/(ne11*ne10);
-        const int64_t i11 = (i - i12*ne11*ne10)/ne10;
-        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10);
-        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+        const int64_t i12 = i/(ne11*ne10);                  // dst batch index
+        const int64_t i11 = (i - i12*ne11*ne10)/ne10;       // dst head index / dst row index within batch
+        const int64_t i10 = (i - i12*ne11*ne10 - i11*ne10); // dst element index within row
+        const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12); // row index in src0

-        GGML_ASSERT(i01 >= 0 && i01 < ne01);
+        GGML_ASSERT(i01 >= 0 && i01 < ne01); // Ensure index is valid for src0's rows

-        ggml_vec_cpy_f32(nc,
-                (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3),
-                (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
+        float * const dst_current_row_ptr = (float *)((char *)dst->data + i10*nb1 + i11*nb2 + i12*nb3);
+        const char * const src_row_ptr = (char *)src0->data + i01*nb01 + i11*nb02 + i12*nb03; // nb02 and nb03 should align with dst's nb2, nb3 for batch/head strides
+
+        if (src0->sq_info != NULL && src0->sq_info->enabled) {
+            ggml_get_rows_smarterquant(src0, src_row_ptr, dst_current_row_ptr);
+        } else {
+            // Original logic for non-SmarterQuant or when sq_info is null/disabled
+            GGML_ASSERT(src0->type == GGML_TYPE_F32 && "Original path expects F32 src for f32 dst in get_rows");
+            ggml_vec_cpy_f32(nc, dst_current_row_ptr, (const float *)src_row_ptr);
+        }
     }
 }

@@ -9899,6 +9930,15 @@

     const struct ggml_tensor * src0 = dst->src[0];

+    // If SmarterQuant is active for src0 and dst is F32, use the SmarterQuant path via ggml_compute_forward_get_rows_f32.
+    // This handles cases where src0 itself might be F32 but has sq_info (though less common for F32)
+    // or where src0 is quantized and has sq_info.
+    if (dst->type == GGML_TYPE_F32 && src0->sq_info != NULL && src0->sq_info->enabled) {
+        ggml_compute_forward_get_rows_f32(params, dst);
+        return;
+    }
+
+    // Original dispatch logic for non-SmarterQuant cases or when dst is not F32
     switch (src0->type) {
         case GGML_TYPE_Q4_0:
         case GGML_TYPE_Q4_1:
@@ -9923,6 +9963,8 @@
         case GGML_TYPE_IQ3_S:
         case GGML_TYPE_IQ2_S:
             {
+                // If we reach here, dst is not F32 or sq_info is not enabled,
+                // so use the standard quantized path.
                 ggml_compute_forward_get_rows_q(params, dst);
             } break;
         case GGML_TYPE_F16:
@@ -9933,7 +9975,7 @@
             {
                 ggml_compute_forward_get_rows_bf16(params, dst);
             } break;
-        case GGML_TYPE_F32:
+        case GGML_TYPE_F32: // This case now only handles non-SmarterQuant F32 src, or when dst is not F32.
         case GGML_TYPE_I32:
             {
                 ggml_compute_forward_get_rows_f32(params, dst);
@@ -13150,6 +13192,67 @@

 // ggml_compute_forward_get_rel_pos

+// SmarterQuant: custom dequantization and unpermutation for a single row.
+// Note: this function assumes src_row_base points to the beginning of the *specific row* being processed.
+// It also assumes that tensor->sq_info and tensor->sq_info->column_permutation are valid.
+void ggml_get_rows_smarterquant(const struct ggml_tensor * tensor, const char * src_row_base, float * dst_row_final_target) {
+    const int64_t ne0 = tensor->ne[0]; // number of elements in the row (columns)
+
+    // Temporary buffer for the dequantized but still permuted row, allocated on the stack.
+    float * dequantized_permuted_row = (float *)alloca(ne0 * sizeof(float));
+    if (!dequantized_permuted_row) {
+        // This should ideally not happen for reasonable ne0 sizes with alloca.
+        GGML_ABORT("alloca failed for dequantized_permuted_row in SmarterQuant");
+    }
+
+    size_t current_segment_src_offset = 0; // byte offset within the current row's data
+    for (int64_t j = 0; j < ne0; j += 256) { // iterate over 256-element segments
+        const int64_t current_block_ne = MIN(256, ne0 - j);
+        const int block_idx_in_row = j / 256;
+        enum ggml_type segment_type;
+
+        // Determine the quantization type for the current segment.
+        if (block_idx_in_row < 4) {
+            segment_type = (enum ggml_type) tensor->sq_info->compression_types[block_idx_in_row];
+        } else {
+            segment_type = (enum ggml_type) tensor->sq_info->compression_types[3];
+        }
+
+        const struct ggml_type_traits * current_qfns = ggml_get_type_traits(segment_type);
+        if (current_qfns->to_float == NULL) {
+            GGML_LOG_ERROR("missing to_float for type %s (segment %lld, block_idx %d)\n", ggml_type_name(segment_type), (long long)j, block_idx_in_row);
+            GGML_ABORT("Unsupported SmarterQuant segment type");
+        }
+
+        if (current_block_ne % current_qfns->blck_size != 0) {
+            GGML_LOG_ERROR("SmarterQuant segment ne %lld not divisible by blck_size %lld for type %s\n", (long long)current_block_ne, (long long)current_qfns->blck_size, ggml_type_name(segment_type));
+            GGML_ABORT("SmarterQuant segment size error");
+        }
+
+        current_qfns->to_float(src_row_base + current_segment_src_offset,
+                               dequantized_permuted_row + j,
+                               current_block_ne);
+
+        current_segment_src_offset += ggml_row_size(segment_type, current_block_ne);
+    }
+
+    // Un-permute: write each dequantized element back to its original column.
+    for (int64_t j_perm = 0; j_perm < ne0; ++j_perm) {
+        dst_row_final_target[tensor->sq_info->column_permutation[j_perm]] = dequantized_permuted_row[j_perm];
+    }
+}
+
 static void ggml_compute_forward_get_rel_pos_f16(
     const struct ggml_compute_params * params,
     struct ggml_tensor * dst) {

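The running `current_segment_src_offset` bookkeeping in `ggml_get_rows_smarterquant` can be modeled in isolation. The bytes-per-segment figures below are assumptions for illustration (8 blocks of an 18-byte Q4_0-style block, resp. a 34-byte Q8_0-style block, per 256 elements); the real code asks `ggml_row_size(segment_type, current_block_ne)` instead.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Assumed bytes per full 256-element segment for two example layouts;
// stand-ins for ggml_row_size, not values read from ggml.
enum class SegType { Q4_0_like, Q8_0_like };

static size_t seg_bytes(SegType t) {
    switch (t) {
        case SegType::Q4_0_like: return 8 * 18; // 144 bytes per 256 elements
        case SegType::Q8_0_like: return 8 * 34; // 272 bytes per 256 elements
    }
    return 0;
}

// Byte offset of each segment within one mixed-type row, mirroring the
// running current_segment_src_offset accumulated by the dequantization loop.
static std::vector<size_t> segment_offsets(const std::vector<SegType> & types) {
    std::vector<size_t> offs;
    size_t off = 0;
    for (SegType t : types) {
        offs.push_back(off);
        off += seg_bytes(t);
    }
    return offs;
}
```

Because segment sizes differ by type, a segment's source offset can only be found by summing the sizes of all segments before it, which is why the loop above walks the row front to back.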
ggml/src/ggml.c

Lines changed: 14 additions & 2 deletions

@@ -6510,15 +6510,27 @@ size_t ggml_quantize_chunk(
     }

     GGML_ASSERT(start % type_traits[type].blck_size == 0);
-    GGML_ASSERT(start % n_per_row == 0);
+    GGML_ASSERT(start % n_per_row == 0); // This should hold even for SmarterQuant calls, where start is 0.

     ggml_quantize_init(type); // this is noop if already initialized

-    const size_t start_row = start / n_per_row;
+    const size_t start_row = start / n_per_row; // Now correctly declared before use.
     const size_t row_size  = ggml_row_size(type, n_per_row);

     size_t result = 0;

+    // The dst pointer in these calls is ((char *) dst + start_row * row_size).
+    // For SmarterQuant calls to ggml_quantize_chunk, start is 0, so start_row is 0.
+    // The `dst` pointer passed *into* ggml_quantize_chunk by the SmarterQuant logic
+    // in llama-quant.cpp is already the correctly offset final destination pointer for the segment.
+    // So `(char *)dst + 0 * row_size` correctly points to the beginning of the destination for the current segment.
     switch (type) {
         case GGML_TYPE_Q4_0: result = quantize_q4_0(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
         case GGML_TYPE_Q4_1: result = quantize_q4_1(src + start, (char *) dst + start_row * row_size, nrows, n_per_row, imatrix); break;
src/llama-impl.h

Lines changed: 3 additions & 0 deletions

@@ -59,3 +59,6 @@ std::string llama_format_tensor_shape(const std::vector<int64_t> & ne);
 std::string llama_format_tensor_shape(const struct ggml_tensor * t);

 std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i);
+
+// Function from llama-quant.cpp
+void llama_model_quantize_impl(const std::string & fname_inp, const std::string & fname_out, const struct llama_model_quantize_params * params);
