
Commit 7117c23

ikawrakow (Iwan Kawrakow), iSevenDays, and TheLegendOfKitty authored
* mxfp4: basics
* mxfp4: Zen4 GEMM
* mxfp4: repacked GEMM (AVX2/Zen4)
* mxfp4: AVX2 GEMM
* mxfp4: NEON GEMM
* mxfp4: repacked GEMM (NEON)
* mxfp4: Metal
* Fix quantized K cache without FA (ikawrakow#680)
  * Prevent assert with quantized K cache and no FA
  * Fix MMQ when running with quantized K cache without FA
  ---------
  Co-authored-by: Iwan Kawrakow <[email protected]>
* Fix for Deepseek r1 parsing (ikawrakow#676)
  * Implement function calling / tools for ik_llama.cpp for Kimi K2
  * Implement basic tool choice
  * Backport llama.cpp tool calls support
  * Enhance function calls with improved chat parser and string utilities
    - Add new chat.h/chat.cpp and chat-parser.h/chat-parser.cpp for better chat handling
    - Improve function calls parsing with fallback to the llama.cpp builder pattern
    - Add string utility functions (starts_with, ends_with, find_partial_stop)
    - Update README with function calls testing instructions
    - Enhance Kimi K2 parser and function calls documentation
    - Add comprehensive test suite for function calls
    - Update CMakeLists.txt and Makefile for new components
  * Enhance function calling with unified streaming and parser improvements
    - Fix streaming content cleanup to prevent function syntax in output
    - Unify content extraction patterns with the llama.cpp approach
    - Improve Kimi K2 parser robustness and partial content handling
    - Add comprehensive test coverage for function call scenarios
    - Optimize chat message parsing and diff computation
  * Replace hardcoded values in kimi_k2_parser.hpp with named constants
    - Add compile-time constants for all token format markers
    - Add compile-time constants for XML format markers
    - Add compile-time constants for simple format patterns
    - Replace all hardcoded string literals with named constants
    - Use compile-time length calculation to avoid manual counting
    - Improve maintainability and reduce magic numbers throughout the parser
  * Fix duplicate common_chat_parse definition
    - Remove the duplicate implementation from chat-parser.cpp
    - Keep a single implementation in chat.cpp following llama.cpp patterns
    - Resolves linker error: multiple definition of common_chat_parse
  * Fix JSON assertion failure in function call parsing
    - Add proper validation that the 'function' field is an object before accessing nested keys
    - Handle a missing 'arguments' field gracefully with default "{}"
    - Prevents crashes when parsing malformed tool call JSON structures
  * Add comprehensive Qwen3 XML tool calling support with unit tests
    - Implement Qwen3 XML parser for the <tool_call>{"name": "func", "arguments": {...}}</tool_call> format
    - Add model detection and routing for Qwen3 vs Kimi-K2 formats
    - Create 8 comprehensive unit tests covering parsing, streaming, and error handling
    - Fix token format cleaning bug in kimi_k2_parser.hpp processing order
    - Remove progressive parsing code and related utilities
    - Add tool injection support for the Qwen3 format in server utils
  * Add DeepSeek R1 function calling support with comprehensive unit tests
    - Implement complete DeepSeek R1 tool call parsing in common_chat_parser.cpp
    - Add DeepSeek R1 model detection and tool injection in deepseek_r1_tools.hpp
    - Update function_calls.hpp with DeepSeek R1 integration and content extraction
    - Update documentation to reflect support for Kimi-K2, Qwen3, and DeepSeek R1 models
    - Add comprehensive unit tests for DeepSeek R1 reasoning, tool calls, and integration
    - Port exact implementation patterns from original llama.cpp for compatibility
    Key features:
    - Native DeepSeek R1 format: <|tool▁calls▁begin|>function<|tool▁sep|>name```json{}```<|tool▁call▁end|><|tool▁calls▁end|>
    - Reasoning content extraction from <think>...</think> tags
    - Multiple tool calls support with separate call blocks
    - Model detection for deepseek-r1, deepseek_r1 naming patterns
    - Integration with incremental parsing and streaming support
  * Add partial parsing support for JSON and regex
    - json-partial.h/cpp: JSON partial parsing functionality
    - regex-partial.h/cpp: regex partial parsing functionality
  * Add format_chat integration tests for Qwen3 tool injection
    - Add test_qwen3_format_chat_integration() to validate the tool injection pipeline
    - Test tool injection conditions and system message enhancement
    - Verify JSON formatting and anti-preamble instructions
    - Add comprehensive test documentation
    Tests confirm tool injection works correctly; the conversational preamble issue is not in ik_llama.cpp but likely in the UI configuration.
  * Fix Qwen3 tool call parsing - pass model name to parser
    The server was not passing the model name to parse_chat_message_incremental(), causing Qwen3 to fall back to the Kimi-K2 parser and return tool calls as content instead of a proper tool_calls array.
  * Fix non-streaming path to use model-specific parsing
    Non-streaming responses were hardcoded to the Kimi-K2 format, causing Qwen3 XML tool calls to be returned as content instead of a proper tool_calls array. The non-streaming path now uses the same model detection as the streaming path for consistency.
  * Update Qwen3 function call handling in server and tests
    - Enhanced server function call detection and response formatting
    - Improved test coverage for Qwen3 tool call scenarios
    - Refined XML parsing for better tool execution support
  * Add DeepSeek-R1 function call parsing support
    Implements comprehensive parsing for all 4 DeepSeek-R1 function call formats:
    - Format 1: standard function call syntax (already supported)
    - Format 2: alternative function call patterns (already supported)
    - Format 3: tools array format - function\n```json\n{"tools": [...]}
    - Format 4: XML wrapped format - <tool_call>function</think>Name\n```json\n{...}```</tool_call>
    Key changes:
    - Added parse_deepseek_r1_tools_array() following the original parse_prefixed_json_tool_call_array pattern
    - Added parse_deepseek_r1_xml_wrapped() following the Hermes-2-Pro XML wrapper patterns
    - Integrated both parsers into the exception handling chain for robust fallback
    - Added comprehensive TDD test coverage for all formats
    - Anonymized all confidential information while preserving functionality
    Resolves the tool_calls_count=0 issue where DeepSeek-R1 models generated valid tool calls but the server failed to parse them.
  * Update function_calls.md documentation for DeepSeek-R1 Format 4
    - Added Format 4 (XML wrapped) documentation with examples
    - Updated implementation notes with the correct parser order (3→4→1→2)
    - Marked all DeepSeek-R1 formats as working (July 2025 update)
    - Updated test status for Formats 3 and 4 as passing
    - Added parse_deepseek_r1_xml_wrapped() function reference
    - Corrected implementation file line numbers
  * Fix merge conflict in test-function-calls.cpp
    - Removed incomplete merge conflict marker from line 3027
    - Ensured all tests compile and pass successfully
    - All DeepSeek-R1 formats (1-4) working correctly
    - All streaming and content cleaning tests passing
  * Fix DeepSeek R1 parsing issue with responses wrapped in think tags
    Restore the missing consume_rest() call from the working PR ikawrakow#648 implementation. When a response contains no tool calls, the remaining content after reasoning parsing must be preserved as displayable content. Fixes the issue where entire responses wrapped in <think> tags resulted in empty content output.
  * Implement proper reasoning handling following original llama.cpp patterns
    - Add missing reasoning_format and reasoning_in_content fields to common_chat_syntax
    - Update try_parse_reasoning to match the original llama.cpp logic exactly
    - Add TDD test case with reasoning_in_content=true for DeepSeek R1
    - Following TDD: the test should now pass with the proper syntax configuration
    Based on original llama.cpp implementation patterns.
  * TDD SUCCESS: Fix DeepSeek R1 thinking tag termination issue
    ✅ Test passes with the reasoning_in_content=true configuration
    - Content properly preserved: '<think>content</think>' displays fully
    - Reasoning field empty as expected
    - Following TDD: the test-first approach validates the fix
    Next: update the server to apply this configuration automatically.
  * Complete server integration fix for DeepSeek R1 thinking tag termination
    - The server now automatically sets reasoning_in_content=true for DeepSeek R1 models
    - Fixes the issue where responses wrapped in <think> tags appear empty to users
  * Add TDD test case for DeepSeek R1 thinking tag termination issue
    - Test reproduces the exact failure scenario reported by a user
    - Validates that reasoning_in_content=true fixes the issue
    - Demonstrates the empty content problem and the working solution
  * Add remaining TDD test changes for DeepSeek R1 thinking tag fix
  * Add debug output after upstream merge
  * Remove temporary benchmark and debug files
    - Remove tests/benchmark-progressive-parsing.cpp (development tool, not part of core functionality)
    - Remove tests/reproduce_bug.sh (debugging script, not needed for the PR)
* Port cpu moe options from mainline (ikawrakow#672)
  * Port cpu moe options from mainline
  * Use strdup and int32_t to follow coding guidelines
* mxfp4: CUDA dequantize
* mxfp4: CUDA GEMV
* mxfp4: CUDA MMQ
* mxfp4: minor CUDA tweaks
---------
Co-authored-by: Iwan Kawrakow <[email protected]>
Co-authored-by: Anton Sokolchenko <[email protected]>
Co-authored-by: Parsa <[email protected]>
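The Qwen3 format referenced in the log above wraps a JSON object in <tool_call>...</tool_call> tags. As a rough, hypothetical illustration of that shape only (this is not the chat-parser code added by the commit, the helper name is made up, and nlohmann::json is assumed to be available as in mainline llama.cpp), a minimal non-streaming extraction might look like:

#include <nlohmann/json.hpp>
#include <optional>
#include <string>

struct tool_call_ex { std::string name; std::string arguments; };

// Naive extraction of the first <tool_call>{"name": ..., "arguments": ...}</tool_call> block.
static std::optional<tool_call_ex> parse_qwen3_tool_call_sketch(const std::string & text) {
    const std::string open = "<tool_call>", close = "</tool_call>";
    const size_t b = text.find(open);
    if (b == std::string::npos) return std::nullopt;
    const size_t e = text.find(close, b + open.size());
    if (e == std::string::npos) return std::nullopt; // partial output: the real parser handles this incrementally
    const std::string payload = text.substr(b + open.size(), e - b - open.size());
    auto j = nlohmann::json::parse(payload, nullptr, /*allow_exceptions=*/false);
    if (!j.is_object() || !j.contains("name") || !j["name"].is_string()) return std::nullopt;
    return tool_call_ex{ j["name"].get<std::string>(),
                         j.contains("arguments") ? j["arguments"].dump() : "{}" };
}

The parsers added by the commit additionally handle streaming/partial output (via the json-partial and regex-partial utilities listed above) and route between per-model formats.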
1 parent 293f4aa commit 7117c23

File tree

22 files changed, +733 −38 lines changed


examples/quantize/quantize.cpp

Lines changed: 1 addition & 0 deletions
@@ -28,6 +28,7 @@ static const std::vector<struct quant_option> QUANT_OPTIONS = {
     { "Q5_0",      LLAMA_FTYPE_MOSTLY_Q5_0,      " 4.33G, +0.0683 ppl @ LLaMA-v1-7B", },
     { "Q5_1",      LLAMA_FTYPE_MOSTLY_Q5_1,      " 4.70G, +0.0349 ppl @ LLaMA-v1-7B", },
     { "Q6_0",      LLAMA_FTYPE_MOSTLY_Q6_0,      " 6.5 bpw quantization", },
+    { "MXFP4",     LLAMA_FTYPE_MOSTLY_MXFP4,     " 4.25 bpw 4-bit float quantization",},
     { "IQ2_XXS",   LLAMA_FTYPE_MOSTLY_IQ2_XXS,   " 2.06 bpw quantization", },
     { "IQ2_XXS_R4",LLAMA_FTYPE_MOSTLY_IQ2_XXS_R4,"IQ2_XXS repacked", },
     { "IQ2_XS",    LLAMA_FTYPE_MOSTLY_IQ2_XS,    " 2.31 bpw quantization", },

ggml/include/ggml.h

Lines changed: 5 additions & 3 deletions
@@ -403,6 +403,7 @@ extern "C" {
         GGML_TYPE_Q4_0_4_4 = 31,
         GGML_TYPE_Q4_0_4_8 = 32,
         GGML_TYPE_Q4_0_8_8 = 33,
+        GGML_TYPE_MXFP4    = 39, // so we are compatible with mainline
         //
         // So we are able to consume MS BitNet I2_S quants
         //
@@ -507,9 +508,10 @@ extern "C" {
         GGML_FTYPE_MOSTLY_IQ4_XS   = 22, // except 1d tensors
         GGML_FTYPE_MOSTLY_IQ1_M    = 23, // except 1d tensors
         GGML_FTYPE_MOSTLY_BF16     = 24, // except 1d tensors
-        GGML_FTYPE_MOSTLY_Q4_0_4_4 = 25, // except 1d tensors
-        GGML_FTYPE_MOSTLY_Q4_0_4_8 = 26, // except 1d tensors
-        GGML_FTYPE_MOSTLY_Q4_0_8_8 = 27, // except 1d tensors
+        GGML_FTYPE_MOSTLY_MXFP4    = 25, // except 1d tensors, using 26 to be compatible with mainline
+        GGML_FTYPE_MOSTLY_Q4_0_4_4 = 26, // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q4_0_4_8 = 27, // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q4_0_8_8 = 28, // except 1d tensors
         //
         GGML_FTYPE_MOSTLY_Q6_0     = 127, // except 1d tensors
         GGML_FTYPE_MOSTLY_IQ1_BN   = 128, // except 1d tensors

ggml/src/ggml-common.h

Lines changed: 18 additions & 0 deletions
@@ -158,6 +158,9 @@ typedef sycl::half2 ggml_half2;
 #define QI1_BN (QK_IQ1BN / (4*QR1_BN))
 #define QR1_BN 8
 
+#define QI_MXFP4 (QK_MXFP4 / (4 * QR_MXFP4))
+#define QR_MXFP4 2
+
 #endif // GGML_COMMON_DECL_CUDA || GGML_COMMON_DECL_HIP
 
 #define QK4_0 32
@@ -174,6 +177,15 @@ typedef struct {
 } block_q4_1;
 static_assert(sizeof(block_q4_1) == 2 * sizeof(ggml_half) + QK4_1 / 2, "wrong q4_1 block size/padding");
 
+// This is unfortunate (block is 17 bytes, so not even a 2-byte alignment)
+// But to be able to use MXFP4-quantized models from mainline, we do the same.
+#define QK_MXFP4 32
+typedef struct {
+    uint8_t e; // E8M0
+    uint8_t qs[QK_MXFP4/2];
+} block_mxfp4;
+static_assert(sizeof(block_mxfp4) == sizeof(uint8_t) + QK_MXFP4/2, "wrong mxfp4 block size/padding");
+
 #define QK5_0 32
 typedef struct {
     ggml_half d; // delta
@@ -2211,5 +2223,11 @@ GGML_TABLE_BEGIN(int8_t, iq6nl_values, 128)
     48, 52, 56, 60, 64, 69, 73, 78, 83, 88, 93, 99, 104, 110, 116, 122,
 GGML_TABLE_END()
 
+// e2m1 values (doubled)
+// ref: https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
+GGML_TABLE_BEGIN(int8_t, kvalues_mxfp4, 16)
+    0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12,
+GGML_TABLE_END()
+
 #endif // GGML_COMMON_IMPL
 #endif // GGML_COMMON_IMPL
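
For reference, the layout above works out to the 4.25 bpw advertised in quantize.cpp: one E8M0 scale byte plus 16 bytes of packed e2m1 nibbles for 32 values, i.e. 17 * 8 / 32 = 4.25 bits per weight. A minimal host-side scalar dequantizer, included here purely as an illustration of the format (it is not part of this diff; the self-contained definitions mirror the ones added above), might look like:

#include <cmath>
#include <cstdint>

#define QK_MXFP4 32

typedef struct {
    uint8_t e;                 // shared E8M0 scale, nominal value 2^(e - 127)
    uint8_t qs[QK_MXFP4 / 2];  // 32 e2m1 codes, two per byte
} block_mxfp4;                 // 1 + 16 = 17 bytes -> 4.25 bits per weight

// e2m1 values, doubled (same convention as the kvalues_mxfp4 table above)
static const int8_t kvalues_mxfp4[16] = { 0, 1, 2, 3, 4, 6, 8, 12, 0, -1, -2, -3, -4, -6, -8, -12 };

// Reference (non-SIMD) dequantization of one block into 32 floats.
static void dequantize_block_mxfp4_ref(const block_mxfp4 * b, float * y) {
    // the table is doubled, so the effective scale is 2^(e-127) / 2 = 2^(e-128)
    const float d = std::ldexp(1.0f, (int) b->e - 128);
    for (int j = 0; j < QK_MXFP4 / 2; ++j) {
        y[j]                = d * kvalues_mxfp4[b->qs[j] & 0x0f]; // low nibble -> first half of the block
        y[j + QK_MXFP4 / 2] = d * kvalues_mxfp4[b->qs[j] >> 4];   // high nibble -> second half
    }
}

The low-nibble/high-nibble placement matches the CUDA dequantize kernel further down, which writes y[j] and y[j+16] from the same byte.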

ggml/src/ggml-cuda.cu

Lines changed: 1 addition & 0 deletions
@@ -3498,6 +3498,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_TYPE_IQ3_S:
         case GGML_TYPE_IQ3_XXS:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ2_KL:
         case GGML_TYPE_IQ3_KS:

ggml/src/ggml-cuda/common.cuh

Lines changed: 7 additions & 0 deletions
@@ -550,6 +550,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_IQ4_NL> {
     static constexpr int qi = QI4_NL;
 };
 
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_MXFP4> {
+    static constexpr int qk = QK4_NL;
+    static constexpr int qr = QR4_NL;
+    static constexpr int qi = QI4_NL;
+};
+
 template<>
 struct ggml_cuda_type_traits<GGML_TYPE_IQ4_XS> {
     static constexpr int qk = QK_K;
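
The new specialization reuses the IQ4_NL constants rather than the QK_MXFP4/QR_MXFP4/QI_MXFP4 macros: both formats pack 32 4-bit values per block, so qk, qr and qi come out identical. A standalone sketch of that equivalence (the underscored names are local stand-ins mirroring the QI/QR definition pattern in ggml-common.h, not the real macros):

// QI = QK / (4*QR) for both formats, with QK = 32 and QR = 2.
constexpr int QK4_NL_   = 32, QR4_NL_   = 2, QI4_NL_   = QK4_NL_   / (4 * QR4_NL_);
constexpr int QK_MXFP4_ = 32, QR_MXFP4_ = 2, QI_MXFP4_ = QK_MXFP4_ / (4 * QR_MXFP4_);
static_assert(QK_MXFP4_ == QK4_NL_ && QR_MXFP4_ == QR4_NL_ && QI_MXFP4_ == QI4_NL_,
              "32 4-bit values per block in both formats, so the IQ4_NL traits can be reused");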

ggml/src/ggml-cuda/convert.cu

Lines changed: 32 additions & 0 deletions
@@ -736,6 +736,27 @@ static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst
     }
 }
 
+template<typename dst_t>
+static __global__ void dequantize_block_mxfp4(const void * __restrict__ vx, dst_t * __restrict__ yy) {
+
+    constexpr uint32_t uval[2] = { 0x00200000, 0x00400000 };
+    const int64_t i = blockIdx.x;
+    const block_mxfp4 * x = (const block_mxfp4 *) vx + i*(QK_K/QK4_NL);
+
+    const int64_t tid = threadIdx.x;
+    const int64_t il = tid/8; // 0...3
+    const int64_t ib = tid%8; // 0...7
+    dst_t * y = yy + i*QK_K + 32*ib + 4*il;
+    const uint8_t * q4 = x[ib].qs + 4*il;
+    union { float f; uint32_t u; } helper;
+    helper.u = x[ib].e >= 2 ? uint32_t(x[ib].e - 1) << 23u : uval[x[ib].e];
+    const float d = helper.f;
+    for (int j = 0; j < 4; ++j) {
+        y[j+ 0] = d * kvalues_mxfp4[q4[j] & 0xf];
+        y[j+16] = d * kvalues_mxfp4[q4[j] >> 4];
+    }
+}
+
 template<typename dst_t>
 static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
     const int64_t i = blockIdx.x;
@@ -1611,6 +1632,13 @@ static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int64_t
     dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
 }
 
+template<typename dst_t>
+static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
+    const int64_t k = nrows * n_per_row;
+    const int nb = (k + QK_K - 1) / QK_K;
+    dequantize_block_mxfp4<<<nb, 32, 0, stream>>>(vx, y);
+}
+
 template<typename dst_t>
 static void dequantize_row_iq1_m_cuda(const void * vx, dst_t * y, const int64_t nrows, const int64_t n_per_row, cudaStream_t stream) {
     const int64_t k = nrows * n_per_row;
@@ -1943,6 +1971,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
             return dequantize_row_iq2_bn_cuda;
         case GGML_TYPE_IQ4_NL:
             return dequantize_row_iq4_nl_cuda;
+        case GGML_TYPE_MXFP4:
+            return dequantize_row_mxfp4_cuda;
         case GGML_TYPE_IQ4_XS:
             return dequantize_row_iq4_xs_cuda;
         case GGML_TYPE_IQ4_KS:
@@ -2044,6 +2074,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
             return dequantize_row_iq2_bn_cuda;
         case GGML_TYPE_IQ4_NL:
             return dequantize_row_iq4_nl_cuda;
+        case GGML_TYPE_MXFP4:
+            return dequantize_row_mxfp4_cuda;
         case GGML_TYPE_IQ4_XS:
             return dequantize_row_iq4_xs_cuda;
         case GGML_TYPE_IQ4_KS:
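
The helper.u expression in dequantize_block_mxfp4 builds the IEEE-754 bit pattern of 2^(e-128) directly instead of calling ldexp: for e >= 2 the exponent field is set to e - 1 (i.e. 2^(e-1-127)), and the two uval constants are the subnormal encodings of 2^-128 and 2^-127 used for e = 0 and e = 1. The extra factor of 1/2 relative to the nominal E8M0 value 2^(e-127) compensates for the doubled kvalues_mxfp4 table. A small host-side sketch of the same trick, using memcpy in place of the kernel's union (illustration only, not repo code):

#include <cstdint>
#include <cstdio>
#include <cstring>

// Decode an E8M0 exponent byte into the effective MXFP4 block scale 2^(e-128).
static float mxfp4_scale(uint8_t e) {
    static const uint32_t uval[2] = { 0x00200000u, 0x00400000u }; // subnormals 2^-128 and 2^-127
    const uint32_t u = e >= 2 ? uint32_t(e - 1) << 23 : uval[e];  // e >= 2: exponent field e-1 -> 2^(e-128)
    float f;
    std::memcpy(&f, &u, sizeof f);                                // bit-cast without UB
    return f;
}

int main() {
    // e = 127 nominally encodes 2^0 = 1, so the effective (halved) scale is 0.5
    std::printf("%g %g %g\n", mxfp4_scale(127), mxfp4_scale(128), mxfp4_scale(0));
    // expected: 0.5 1 2.93874e-39
    return 0;
}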

ggml/src/ggml-cuda/mmq.cu

Lines changed: 4 additions & 0 deletions
@@ -94,6 +94,9 @@ void ggml_cuda_op_mul_mat_q(
         case GGML_TYPE_IQ4_NL:
             mul_mat_q_case<GGML_TYPE_IQ4_NL>(ctx, args, stream);
             break;
+        case GGML_TYPE_MXFP4:
+            mul_mat_q_case<GGML_TYPE_MXFP4>(ctx, args, stream);
+            break;
         case GGML_TYPE_IQ2_KL:
             mul_mat_q_case<GGML_TYPE_IQ2_KL>(ctx, args, stream);
             break;
@@ -210,6 +213,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         case GGML_TYPE_IQ1_S_R4:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_IQ2_KL:
         case GGML_TYPE_IQ3_KS:
         case GGML_TYPE_IQ4_KSS:

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 72 additions & 0 deletions
@@ -84,6 +84,7 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
             return MMQ_Q8_1_DS_LAYOUT_DS4;
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_IQ2_KS:
         case GGML_TYPE_IQ2_K:
         case GGML_TYPE_IQ2_K_R4:
@@ -204,6 +205,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
         case GGML_TYPE_IQ1_S_R4: return MMQ_DP4A_TXS_Q8_0;
         case GGML_TYPE_IQ4_XS  : return MMQ_DP4A_TXS_Q8_0;
         case GGML_TYPE_IQ4_NL  : return MMQ_DP4A_TXS_Q8_0;
+        case GGML_TYPE_MXFP4   : return MMQ_DP4A_TXS_Q8_0;
         case GGML_TYPE_IQ2_KL  : return MMQ_DP4A_TXS_Q8_0;
         case GGML_TYPE_IQ3_KS  : return MMQ_DP4A_TXS_Q8_0;
         case GGML_TYPE_IQ4_KSS : return MMQ_DP4A_TXS_Q8_0;
@@ -263,6 +265,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
         case GGML_TYPE_IQ1_S_R4: return MMQ_MMA_TILE_X_K_Q8_0;
         case GGML_TYPE_IQ4_XS  : return MMQ_MMA_TILE_X_K_Q8_0;
         case GGML_TYPE_IQ4_NL  : return MMQ_MMA_TILE_X_K_Q8_0;
+        case GGML_TYPE_MXFP4   : return MMQ_MMA_TILE_X_K_Q8_0;
         case GGML_TYPE_IQ2_KL  : return MMQ_MMA_TILE_X_K_Q8_0;
         case GGML_TYPE_IQ3_KS  : return MMQ_MMA_TILE_X_K_Q8_0;
         case GGML_TYPE_IQ4_KSS : return MMQ_MMA_TILE_X_K_Q8_0;
@@ -2078,6 +2081,67 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     }
 }
 
+template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_mxfp4(
+    const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+
+#ifdef INT8_MMA_AVAILABLE
+    int   * x_qs = (int   *) x_tile;
+    float * x_df = (float *) (x_qs + WARP_SIZE*2);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_MXFP4, mmq_y);
+    int   * x_qs = (int   *) x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // INT8_MMA_AVAILABLE
+
+    const int kbx  = threadIdx.x / QI4_NL;
+    const int kqsx = threadIdx.x % QI4_NL;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+        int i = i0 + threadIdx.y;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_mxfp4 * bxi = (const block_mxfp4 *)(x + i*stride) + kbx0 + kbx;
+
+        const int aux_q4 = get_int_b1(bxi->qs, kqsx);
+        const int2 v = get_int_from_table_16(aux_q4, kvalues_mxfp4);
+        const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
+#ifdef INT8_MMA_AVAILABLE
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
+        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
+#else
+        x_qs[i*(2*WARP_SIZE + 1)     + k0 + 0] = v.x;
+        x_qs[i*(2*WARP_SIZE + 1)     + k0 + 4] = v.y;
+#endif // INT8_MMA_AVAILABLE
+    }
+
+    const int blocks_per_tile_x_row = WARP_SIZE / QI4_NL;
+    const int kbxd = threadIdx.x % blocks_per_tile_x_row;
+
+    union { float f; uint32_t u; } helper;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * QI4_NL) {
+        int i = i0 + threadIdx.y * QI4_NL + threadIdx.x / blocks_per_tile_x_row;
+
+        if (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_mxfp4 * bxi = (const block_mxfp4 *)(x + i*stride) + kbx0 + kbxd;
+        helper.u = bxi->e ? uint32_t(bxi->e) << 23u : 0x00400000;
+
+#ifdef INT8_MMA_AVAILABLE
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kbxd] = 0.5f * helper.f;
+#else
+        x_df[i*(WARP_SIZE/4) + i/4   + kbxd] = 0.5f * helper.f;
+#endif // INT8_MMA_AVAILABLE
+    }
+}
+
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq2_xxs(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
@@ -3624,6 +3688,13 @@ struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_NL> {
     static constexpr vec_dot_mmq_t vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
 };
 
+template <int mmq_x, int mmq_y, int nwarps, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_MXFP4> {
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_mxfp4<mmq_y, nwarps, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, nwarps, MMQ_Q8_1_DS_LAYOUT_D4>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y, nwarps>;
+};
+
 template <int mmq_x, int mmq_y, int nwarps, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, nwarps, need_check, GGML_TYPE_IQ4_XS> {
     static constexpr load_tiles_mmq_t load_tiles = load_tiles_iq4_xs<mmq_y, nwarps, need_check>;
@@ -4164,6 +4235,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS);
 extern DECL_MMQ_CASE(GGML_TYPE_IQ3_S);
 extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S);
 extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
+extern DECL_MMQ_CASE(GGML_TYPE_MXFP4);
 extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
 extern DECL_MMQ_CASE(GGML_TYPE_IQ2_KL);
 extern DECL_MMQ_CASE(GGML_TYPE_IQ3_KS);
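
load_tiles_mxfp4 is the only MXFP4-specific piece of the MMQ path: it expands the nibbles to int8 through kvalues_mxfp4 at tile-load time and stores 0.5f * 2^(e-127) as the per-block scale, so the existing q8_0 x q8_1 MMA/dp4a dot products are reused unchanged (see the mmq_type_traits entry above). A simplified scalar reference of what that reused kernel effectively computes for one 32-value block, written under those assumptions rather than taken from the CUDA code:

#include <cstdint>

// x_int8: the 32 LUT-expanded MXFP4 values; d_x = 0.5f * 2^(e-127) as stored by load_tiles_mxfp4.
// q8/d8:  the activations quantized to 8-bit with a per-block float scale (q8_1 style).
static float mxfp4_block_dot_ref(const int8_t * x_int8, float d_x,
                                 const int8_t * q8,     float d8, int n = 32) {
    int32_t sumi = 0;
    for (int j = 0; j < n; ++j) {
        sumi += int32_t(x_int8[j]) * int32_t(q8[j]); // what dp4a / int8 MMA accumulates
    }
    return d_x * d8 * float(sumi);                   // scales applied once per block
}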

ggml/src/ggml-cuda/mmvq.cu

Lines changed: 14 additions & 0 deletions
@@ -31,6 +31,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
         case GGML_TYPE_IQ1_S  : return vec_dot_iq1_s_q8_1;
         case GGML_TYPE_IQ1_M  : return vec_dot_iq1_m_q8_1;
         case GGML_TYPE_IQ4_NL : return vec_dot_iq4_nl_q8_1;
+        case GGML_TYPE_MXFP4  : return vec_dot_mxfp4_q8_1;
         case GGML_TYPE_IQ4_XS : return vec_dot_iq4_xs_q8_1;
         case GGML_TYPE_IQ3_S  : return vec_dot_iq3_s_q8_1;
         default               : return nullptr;
@@ -56,6 +57,7 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
         case GGML_TYPE_IQ3_XXS : return VDR_IQ3_XXS_Q8_1_MMVQ;
         case GGML_TYPE_IQ3_S   : return VDR_IQ3_S_Q8_1_MMVQ;
         case GGML_TYPE_IQ4_NL  : return VDR_IQ4_NL_Q8_1_MMVQ;
+        case GGML_TYPE_MXFP4   : return VDR_MXFP4_Q8_1_MMVQ;
         case GGML_TYPE_IQ4_XS  : return VDR_IQ4_XS_Q8_1_MMVQ;
         default                : return 1;
     }
@@ -417,6 +419,14 @@ static void mul_mat_vec_iq4_nl_q8_1_cuda(
     mul_mat_vec_q_cuda<GGML_TYPE_IQ4_NL>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
 }
 
+static void mul_mat_vec_mxfp4_q8_1_cuda(
+    const void * vx, const void * vy, float * dst, const char * ids_data,
+    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
+    const int ne2, const uint64_t nb02, const uint64_t nb12, const uint64_t nb2, int64_t ids_nb0, cudaStream_t stream) {
+
+    mul_mat_vec_q_cuda<GGML_TYPE_MXFP4>(vx, vy, dst, ids_data, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
+}
+
 static void mul_mat_vec_iq4_xs_q8_1_cuda(
     const void * vx, const void * vy, float * dst, const char * ids_data,
     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst,
@@ -509,6 +519,9 @@ static void ggml_cuda_op_mul_mat_vec_q_impl(ggml_backend_cuda_context & ctx, ggm
         case GGML_TYPE_IQ4_NL:
             mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
             break;
+        case GGML_TYPE_MXFP4:
+            mul_mat_vec_mxfp4_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
+            break;
         case GGML_TYPE_IQ4_XS:
             mul_mat_vec_iq4_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ids_data, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, ne2, nb02, nb12, nb2, ids_nb0, stream);
             break;
@@ -686,6 +699,7 @@ bool ggml_cuda_mmvq_type_supported(ggml_type src0_type) {
         case GGML_TYPE_IQ1_BN:
         case GGML_TYPE_IQ2_BN:
         case GGML_TYPE_IQ4_NL:
+        case GGML_TYPE_MXFP4:
         case GGML_TYPE_IQ4_XS:
         case GGML_TYPE_IQ2_K:
         case GGML_TYPE_IQ2_KL:

Lines changed: 1 addition & 2 deletions

@@ -1,5 +1,4 @@
-// This file has been autogenerated by generate_cu_files.py, do not edit manually.
-
 #include "../mmq.cuh"
 
 DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
+DECL_MMQ_CASE(GGML_TYPE_MXFP4);

0 commit comments