-
Notifications
You must be signed in to change notification settings - Fork 4.4k
[WIP] gemm block quantization for llm decoder style #6439
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
17b7109
9c22a28
edbb1aa
6fd6909
4d636ba
2dfb93c
bdc9f2d
0cc494f
e9c4943
58cc1f3
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -88,7 +88,24 @@ int Gemm::load_model(const ModelBin& mb) | |
| if (transB == 0) | ||
| B_data = mb.load(constantN, constantK, 0); | ||
| else | ||
| B_data = mb.load(constantK, constantN, 0); | ||
| { | ||
| if (int8_scale_term == 5) | ||
| { | ||
| // int6 block quantize | ||
| // TODO auto type for int6 storage | ||
| B_data = mb.load((constantK + 3) / 4 * 3, constantN, 0); | ||
| } | ||
| else if (int8_scale_term == 6) | ||
| { | ||
| // int4 block quantize | ||
| // TODO auto type for int4 storage | ||
| B_data = mb.load((constantK + 1) / 2, constantN, 0); | ||
| } | ||
| else | ||
| { | ||
| B_data = mb.load(constantK, constantN, 0); | ||
| } | ||
| } | ||
| if (B_data.empty()) | ||
| return -100; | ||
| } | ||
|
|
@@ -119,7 +136,186 @@ int Gemm::load_model(const ModelBin& mb) | |
|
|
||
| if (constantB == 1) | ||
| { | ||
| B_data_int8_scale = mb.load(1, 1)[0]; | ||
| if (int8_scale_term == 4) | ||
| { | ||
| // int8 block quantize | ||
| // assert transB == 1 // FIXME hardcode | ||
| const int block_size = 64; // FIXME hardcode | ||
| const int block_count = (constantK + block_size - 1) / block_size; | ||
|
|
||
| B_data_quantize_scales = mb.load(block_count, constantN, 0); | ||
|
|
||
| // dequantize B_data to fp32 | ||
| Mat B_data_fp32(constantK, constantN); | ||
| if (B_data_fp32.empty()) | ||
| return -100; | ||
|
|
||
| for (int i = 0; i < constantN; i++) | ||
| { | ||
| const signed char* i8ptr = B_data.row<const signed char>(i); | ||
| float* ptr = B_data_fp32.row(i); | ||
| float* scale_ptr = B_data_quantize_scales.row(i); | ||
|
|
||
| for (int j = 0; j < block_count; j++) | ||
| { | ||
| // block dequantize | ||
| const signed char* i8ptr1 = i8ptr + j * block_size; | ||
| float scale = scale_ptr[j]; | ||
| if (scale == 0.f) | ||
| scale = 1.f; | ||
| const float inv_scale = 1.f / scale; | ||
| float* ptr1 = ptr + j * block_size; | ||
| const int block_size1 = std::min(block_size, constantK - j * block_size); | ||
|
|
||
| for (int k = 0; k < block_size1; k++) | ||
| { | ||
| ptr1[k] = i8ptr1[k] * inv_scale; | ||
| } | ||
| } | ||
| } | ||
|
|
||
| B_data = B_data_fp32; | ||
|
|
||
| // reset int8_scale_term to use fp32 path | ||
| int8_scale_term = 0; | ||
| } | ||
| else if (int8_scale_term == 5) | ||
| { | ||
| // int6 block quantize | ||
| // assert transB == 1 // FIXME hardcode | ||
| const int block_size = 64; // FIXME hardcode | ||
| const int block_count = (constantK + block_size - 1) / block_size; | ||
|
|
||
| B_data_quantize_scales = mb.load(block_count, constantN, 0); | ||
|
|
||
| // dequantize B_data to fp32 | ||
| Mat B_data_fp32(constantK, constantN); | ||
| if (B_data_fp32.empty()) | ||
| return -100; | ||
|
|
||
| union i6x4_t | ||
| { | ||
| signed char i6[3]; | ||
| struct | ||
| { | ||
| signed char i6_a : 6; | ||
| signed char i6_b : 6; | ||
| signed char i6_c : 6; | ||
| signed char i6_d : 6; | ||
| } __attribute__((packed)); | ||
| }; | ||
|
|
||
| for (int i = 0; i < constantN; i++) | ||
| { | ||
| const i6x4_t* i6ptr = B_data.row<const i6x4_t>(i); | ||
| float* ptr = B_data_fp32.row(i); | ||
| float* scale_ptr = B_data_quantize_scales.row(i); | ||
|
|
||
| for (int j = 0; j < block_count; j++) | ||
| { | ||
| // block dequantize | ||
| const i6x4_t* i6ptr1 = i6ptr + j * block_size / 4; | ||
| // Prevent division by zero: if scale_ptr[j] == 0, use 1.0 as safe default | ||
| const float safe_scale = (scale_ptr[j] == 0.f) ? 1.f : scale_ptr[j]; | ||
| const float inv_scale = 1.f / safe_scale; | ||
| float* ptr1 = ptr + j * block_size; | ||
| const int block_size1 = std::min(block_size, constantK - j * block_size); | ||
|
|
||
| int k = 0; | ||
| for (; k + 3 < block_size1; k += 4) | ||
| { | ||
| ptr1[k] = i6ptr1[k / 4].i6_a * inv_scale; | ||
| ptr1[k + 1] = i6ptr1[k / 4].i6_b * inv_scale; | ||
| ptr1[k + 2] = i6ptr1[k / 4].i6_c * inv_scale; | ||
| ptr1[k + 3] = i6ptr1[k / 4].i6_d * inv_scale; | ||
| } | ||
| for (; k + 2 < block_size1; k += 3) | ||
| { | ||
| ptr1[k] = i6ptr1[k / 4].i6_a * inv_scale; | ||
| ptr1[k + 1] = i6ptr1[k / 4].i6_b * inv_scale; | ||
| ptr1[k + 2] = i6ptr1[k / 4].i6_c * inv_scale; | ||
| } | ||
| for (; k + 1 < block_size1; k += 2) | ||
| { | ||
| ptr1[k] = i6ptr1[k / 4].i6_a * inv_scale; | ||
| ptr1[k + 1] = i6ptr1[k / 4].i6_b * inv_scale; | ||
| } | ||
| for (; k < block_size1; k++) | ||
| { | ||
| ptr1[k] = i6ptr1[k / 4].i6_a * inv_scale; | ||
| } | ||
| } | ||
| } | ||
|
|
||
| B_data = B_data_fp32; | ||
|
|
||
| // reset int8_scale_term to use fp32 path | ||
| int8_scale_term = 0; | ||
| } | ||
| else if (int8_scale_term == 6) | ||
| { | ||
| // int4 block quantize | ||
| // assert transB == 1 // FIXME hardcode | ||
| const int block_size = 64; // FIXME hardcode | ||
| const int block_count = (constantK + block_size - 1) / block_size; | ||
|
|
||
| B_data_quantize_scales = mb.load(block_count, constantN, 0); | ||
|
|
||
| // dequantize B_data to fp32 | ||
| Mat B_data_fp32(constantK, constantN); | ||
| if (B_data_fp32.empty()) | ||
| return -100; | ||
|
|
||
| union i4x2_t | ||
| { | ||
| signed char i4; | ||
| struct | ||
| { | ||
| signed char i4_low : 4; | ||
| signed char i4_high : 4; | ||
| } __attribute__((packed)); | ||
|
||
| }; | ||
|
|
||
| for (int i = 0; i < constantN; i++) | ||
| { | ||
| const i4x2_t* i4ptr = B_data.row<const i4x2_t>(i); | ||
| float* ptr = B_data_fp32.row(i); | ||
| float* scale_ptr = B_data_quantize_scales.row(i); | ||
|
|
||
| for (int j = 0; j < block_count; j++) | ||
| { | ||
| // block dequantize | ||
| const i4x2_t* i4ptr1 = i4ptr + j * block_size / 2; | ||
| // Defensive: avoid division by zero | ||
| float scale = scale_ptr[j]; | ||
| if (scale == 0.f) | ||
| scale = 1.f; | ||
| const float inv_scale = 1.f / scale; | ||
| float* ptr1 = ptr + j * block_size; | ||
| const int block_size1 = std::min(block_size, constantK - j * block_size); | ||
|
|
||
| int k = 0; | ||
| for (; k + 1 < block_size1; k += 2) | ||
| { | ||
| ptr1[k] = i4ptr1[k / 2].i4_low * inv_scale; | ||
| ptr1[k + 1] = i4ptr1[k / 2].i4_high * inv_scale; | ||
| } | ||
| for (; k < block_size1; k++) | ||
| { | ||
| ptr1[k] = i4ptr1[k / 2].i4_low * inv_scale; | ||
| } | ||
| } | ||
| } | ||
|
|
||
| B_data = B_data_fp32; | ||
|
|
||
| // reset int8_scale_term to use fp32 path | ||
| int8_scale_term = 0; | ||
| } | ||
| else | ||
| { | ||
| B_data_int8_scale = mb.load(1, 1)[0]; | ||
| } | ||
| } | ||
| } | ||
| #endif // NCNN_INT8 | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -37,6 +37,9 @@ endif() | |
| add_executable(ncnn2int8 ncnn2int8.cpp) | ||
| target_link_libraries(ncnn2int8 PRIVATE ncnn) | ||
|
|
||
| add_executable(ncnnllm2int468 ncnnllm2int468.cpp) | ||
| target_link_libraries(ncnnllm2int468 PRIVATE ncnn) | ||
|
Comment on lines
+40
to
+41
|
||
|
|
||
| # add ncnn2int8 tool to a virtual project group | ||
| set_property(TARGET ncnn2int8 PROPERTY FOLDER "tools/optimization") | ||
| ncnn_install_tool(ncnn2table) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
__attribute__((packed))attribute is GCC-specific and not portable. This will fail on MSVC. Consider using#pragma packfor cross-platform compatibility or conditionally compile based on compiler.