
Commit 80fe9c3

Restore CUDA MMQ IQ1_S and Q4_1 kernels
1 parent 9a10002 commit 80fe9c3

File tree

3 files changed: +21 -110 lines

CMakeLists.txt
ggml/src/ggml-cuda/mmq.cu
ggml/src/ggml-cuda/mmq.cuh

CMakeLists.txt

Lines changed: 9 additions & 98 deletions
@@ -64,7 +64,6 @@ option(LLAMA_OPENMP "llama: use OpenMP"
 
 # Croco.Cpp Specifics
 option(LLAMA_CUDA_FA_ALL_QUANTS "llama: compile 18 quants for FlashAttention" OFF)
-option(LLAMA_CUDA_DISABLE_MMQ_IQ1_S_Q4_1 "llama: compile 18 quants for FlashAttention" OFF)
 option(GGML_CUDA_USE_GRAPHS "Use Cuda Graphs to increase a bit performancess" OFF)
 set(GGML_SCHED_MAX_COPIES "1" CACHE STRING "llama: max input copies for pipeline parallelism")
 set(LLAMA_SCHED_MAX_COPIES "1" CACHE STRING "llama: max input copies for pipeline parallelism")

@@ -101,6 +100,8 @@ file(GLOB GGML_SOURCES_CUDA "ggml/src/ggml-cuda/*.cu")
 list(APPEND GGML_SOURCES_CUDA "ggml/src/ggml-cuda/ggml-cuda.cu")
 file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/fattn-wmma*.cu")
 list(APPEND GGML_SOURCES_CUDA ${SRCS})
+file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq*.cu")
+list(APPEND GGML_SOURCES_CUDA ${SRCS})
 set(GGML_V3_CUDA_SOURCES otherarch/ggml_v3-cuda.cu otherarch/ggml_v3-cuda.h)
 set(GGML_V2_CUDA_SOURCES otherarch/ggml_v2-cuda.cu otherarch/ggml_v2-cuda.h)
 set(GGML_V2_LEGACY_CUDA_SOURCES otherarch/ggml_v2-cuda-legacy.cu otherarch/ggml_v2-cuda-legacy.h)

@@ -160,55 +161,10 @@ if (LLAMA_CUBLAS)
 if (GGML_CUDA_USE_GRAPHS)
 add_compile_definitions(GGML_CUDA_USE_GRAPHS)
 endif()
-
-if (LLAMA_CUDA_DISABLE_MMQ_IQ1_S_Q4_1)
-# all quants necessary for Kobold CPP Frankenstein are compiled
-# the other are ignored but not deleted from the ggml_cuda templates directory
-# file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu")
-# list(APPEND GGML_SOURCES_CUDA ${SRCS})
-file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_m.cu")
-list(APPEND GGML_SOURCES_CUDA ${SRCS})
-file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu")
-list(APPEND GGML_SOURCES_CUDA ${SRCS})
-file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu")
-list(APPEND GGML_SOURCES_CUDA ${SRCS})
-file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu")
-list(APPEND GGML_SOURCES_CUDA ${SRCS})
-file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu")
-list(APPEND GGML_SOURCES_CUDA ${SRCS})
-file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu")
-list(APPEND GGML_SOURCES_CUDA ${SRCS})
-file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu")
-list(APPEND GGML_SOURCES_CUDA ${SRCS})
-file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu")
-list(APPEND GGML_SOURCES_CUDA ${SRCS})
-file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq-instance-q4_0.cu")
-list(APPEND GGML_SOURCES_CUDA ${SRCS})
-# file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq-instance-q4_1.cu")
-# list(APPEND GGML_SOURCES_CUDA ${SRCS})
-file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq-instance-q5_0.cu")
-list(APPEND GGML_SOURCES_CUDA ${SRCS})
-file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq-instance-q5_1.cu")
-list(APPEND GGML_SOURCES_CUDA ${SRCS})
-file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq-instance-q6_0.cu")
-list(APPEND GGML_SOURCES_CUDA ${SRCS})
-file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq-instance-q2_k.cu")
-list(APPEND GGML_SOURCES_CUDA ${SRCS})
-file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq-instance-q3_k.cu")
-list(APPEND GGML_SOURCES_CUDA ${SRCS})
-file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq-instance-q4_k.cu")
-list(APPEND GGML_SOURCES_CUDA ${SRCS})
-file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq-instance-q5_k.cu")
-list(APPEND GGML_SOURCES_CUDA ${SRCS})
-file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq-instance-q6_k.cu")
-list(APPEND GGML_SOURCES_CUDA ${SRCS})
-file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq-instance-q8_0.cu")
-list(APPEND GGML_SOURCES_CUDA ${SRCS})
-else ()
-# Build All MMQ Kernels
-file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq*.cu")
-list(APPEND GGML_SOURCES_CUDA ${SRCS})
-endif()
+
+# Build All MMQ Kernels
+file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq*.cu")
+list(APPEND GGML_SOURCES_CUDA ${SRCS})
 
 if (LLAMA_CUDA_FA_ALL_QUANTS)
 # all quants necessary for Kobold CPP Frankenstein are compiled

@@ -374,54 +330,9 @@ if (LLAMA_HIPBLAS)
 target_compile_definitions(ggml-rocm PUBLIC GGML_CUDA_FORCE_DMMV)
 endif()
 
-if (LLAMA_CUDA_DISABLE_MMQ_IQ1_S_Q4_1)
-# all quants necessary for Kobold CPP Frankenstein are compiled
-# the other are ignored but not deleted from the ggml_cuda templates directory
-# file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu")
-# list(APPEND GGML_SOURCES_ROCM ${SRCS})
-file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_m.cu")
-list(APPEND GGML_SOURCES_ROCM ${SRCS})
-file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xxs.cu")
-list(APPEND GGML_SOURCES_ROCM ${SRCS})
-file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_xs.cu")
-list(APPEND GGML_SOURCES_ROCM ${SRCS})
-file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq-instance-iq2_s.cu")
-list(APPEND GGML_SOURCES_ROCM ${SRCS})
-file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_xxs.cu")
-list(APPEND GGML_SOURCES_ROCM ${SRCS})
-file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq-instance-iq3_s.cu")
-list(APPEND GGML_SOURCES_ROCM ${SRCS})
-file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_xs.cu")
-list(APPEND GGML_SOURCES_ROCM ${SRCS})
-file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq-instance-iq4_nl.cu")
-list(APPEND GGML_SOURCES_ROCM ${SRCS})
-file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq-instance-q4_0.cu")
-list(APPEND GGML_SOURCES_ROCM ${SRCS})
-# file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq-instance-q4_1.cu")
-# list(APPEND GGML_SOURCES_ROCM ${SRCS})
-file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq-instance-q5_0.cu")
-list(APPEND GGML_SOURCES_ROCM ${SRCS})
-file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq-instance-q5_1.cu")
-list(APPEND GGML_SOURCES_ROCM ${SRCS})
-file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq-instance-q6_0.cu")
-list(APPEND GGML_SOURCES_ROCM ${SRCS})
-file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq-instance-q2_k.cu")
-list(APPEND GGML_SOURCES_ROCM ${SRCS})
-file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq-instance-q3_k.cu")
-list(APPEND GGML_SOURCES_ROCM ${SRCS})
-file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq-instance-q4_k.cu")
-list(APPEND GGML_SOURCES_ROCM ${SRCS})
-file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq-instance-q5_k.cu")
-list(APPEND GGML_SOURCES_ROCM ${SRCS})
-file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq-instance-q6_k.cu")
-list(APPEND GGML_SOURCES_ROCM ${SRCS})
-file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq-instance-q8_0.cu")
-list(APPEND GGML_SOURCES_ROCM ${SRCS})
-else ()
-# Build All MMQ Kernels
-file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq*.cu")
-list(APPEND GGML_SOURCES_ROCM ${SRCS})
-endif()
+# Build All MMQ Kernels
+file(GLOB SRCS "ggml/src/ggml-cuda/template-instances/mmq*.cu")
+list(APPEND GGML_SOURCES_ROCM ${SRCS})
 
 if (LLAMA_CUDA_FA_ALL_QUANTS)
 # all quants necessary for Kobold CPP Frankenstein are compiled
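Note on the restored glob: file(GLOB ... mmq*.cu) picks up one small translation unit per quantization type under ggml/src/ggml-cuda/template-instances/, including mmq-instance-iq1_s.cu and mmq-instance-q4_1.cu, which the removed LLAMA_CUDA_DISABLE_MMQ_IQ1_S_Q4_1 branch used to skip. The commit does not show the instance files themselves, so the following is only a sketch of what such a file is assumed to contain, based on the DECL_MMQ_CASE declarations visible in mmq.cuh further down:

// Assumed shape of ggml/src/ggml-cuda/template-instances/mmq-instance-iq1_s.cu
// (not part of this diff): each instance file is taken to do nothing but
// instantiate the mul_mat_q_case<> template for one ggml_type via
// DECL_MMQ_CASE, so globbing mmq*.cu again is what compiles the IQ1_S and
// Q4_1 kernels back into the CUDA and ROCm builds.
#include "../mmq.cuh"

DECL_MMQ_CASE(GGML_TYPE_IQ1_S);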

ggml/src/ggml-cuda/mmq.cu

Lines changed: 6 additions & 6 deletions
@@ -34,7 +34,7 @@ void ggml_cuda_op_mul_mat_q(
 case GGML_TYPE_Q4_0:
 mul_mat_q_case<GGML_TYPE_Q4_0>(ctx, args, stream);
 break;
-//case GGML_TYPE_Q4_1:
+case GGML_TYPE_Q4_1:
 mul_mat_q_case<GGML_TYPE_Q4_1>(ctx, args, stream);
 break;
 case GGML_TYPE_Q5_0:

@@ -82,9 +82,9 @@ void ggml_cuda_op_mul_mat_q(
 case GGML_TYPE_IQ3_S:
 mul_mat_q_case<GGML_TYPE_IQ3_S>(ctx, args, stream);
 break;
-//case GGML_TYPE_IQ1_S:
-//mul_mat_q_case<GGML_TYPE_IQ1_S>(ctx, args, stream);
-//break;
+case GGML_TYPE_IQ1_S:
+mul_mat_q_case<GGML_TYPE_IQ1_S>(ctx, args, stream);
+break;
 case GGML_TYPE_IQ4_XS:
 mul_mat_q_case<GGML_TYPE_IQ4_XS>(ctx, args, stream);
 break;

@@ -112,7 +112,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
 
 switch (type) {
 case GGML_TYPE_Q4_0:
-//case GGML_TYPE_Q4_1:
+case GGML_TYPE_Q4_1:
 case GGML_TYPE_Q5_0:
 case GGML_TYPE_Q5_1:
 case GGML_TYPE_Q6_0:

@@ -128,7 +128,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
 case GGML_TYPE_IQ2_S:
 case GGML_TYPE_IQ3_XXS:
 case GGML_TYPE_IQ3_S:
-//case GGML_TYPE_IQ1_S:
+case GGML_TYPE_IQ1_S:
 case GGML_TYPE_IQ4_XS:
 case GGML_TYPE_IQ4_NL:
 mmq_supported = true;
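The two switches above are the host-side plumbing: ggml_cuda_should_use_mmq() advertises which quantization types have an MMQ kernel, and ggml_cuda_op_mul_mat_q() forwards a supported type to its mul_mat_q_case<> instantiation. A minimal illustration of what re-enabling the Q4_1 and IQ1_S entries changes (the helper below is hypothetical and not part of the commit; only ggml_cuda_should_use_mmq and mul_mat_q_case are taken from the diff):

// Hypothetical helper, for illustration only: after this commit,
// ggml_cuda_should_use_mmq() reports true for GGML_TYPE_Q4_1 (and likewise
// GGML_TYPE_IQ1_S), so a caller that gates on it can reach the restored
// kernel instantiation instead of always taking a non-MMQ path.
static bool run_q4_1_mmq_if_supported(ggml_backend_cuda_context & ctx,
                                      const mmq_args & args,
                                      int cc, int64_t ne11,
                                      cudaStream_t stream) {
    if (!ggml_cuda_should_use_mmq(GGML_TYPE_Q4_1, cc, ne11)) {
        return false;  // some other matrix-multiplication path would be used
    }
    mul_mat_q_case<GGML_TYPE_Q4_1>(ctx, args, stream);  // restored MMQ kernel
    return true;
}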

ggml/src/ggml-cuda/mmq.cuh

Lines changed: 6 additions & 6 deletions
@@ -156,7 +156,7 @@ static constexpr __device__ int get_mmq_y_device() {
 
 static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml_type type, int mmq_y) {
 return type == GGML_TYPE_Q4_0 ? MMQ_DP4A_TXS_Q4_0 :
-//type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 :
+type == GGML_TYPE_Q4_1 ? MMQ_DP4A_TXS_Q4_1 :
 type == GGML_TYPE_Q5_0 ? MMQ_DP4A_TXS_Q8_0 :
 type == GGML_TYPE_Q5_1 ? MMQ_DP4A_TXS_Q8_1 :
 type == GGML_TYPE_Q6_0 ? MMQ_DP4A_TXS_Q8_0 :

@@ -172,7 +172,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
 type == GGML_TYPE_IQ2_S ? MMQ_DP4A_TXS_Q8_0_16 :
 type == GGML_TYPE_IQ3_XXS ? MMQ_DP4A_TXS_Q8_0 :
 type == GGML_TYPE_IQ3_S ? MMQ_DP4A_TXS_Q8_0 :
-//type == GGML_TYPE_IQ1_S ? MMQ_DP4A_TXS_Q8_0 :
+type == GGML_TYPE_IQ1_S ? MMQ_DP4A_TXS_Q8_0 :
 type == GGML_TYPE_IQ4_XS ? MMQ_DP4A_TXS_Q8_0 :
 type == GGML_TYPE_IQ4_NL ? MMQ_DP4A_TXS_Q8_0 :
 tile_x_sizes{0, 0, 0};

@@ -192,7 +192,7 @@ static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
 
 static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
 return type == GGML_TYPE_Q4_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
-//type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q8_1 :
+type == GGML_TYPE_Q4_1 ? MMQ_MMA_TILE_X_K_Q8_1 :
 type == GGML_TYPE_Q5_0 ? MMQ_MMA_TILE_X_K_Q8_0 :
 type == GGML_TYPE_Q5_1 ? MMQ_MMA_TILE_X_K_Q8_1 :
 type == GGML_TYPE_Q6_0 ? MMQ_MMA_TILE_X_K_Q8_0 :

@@ -208,7 +208,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
 type == GGML_TYPE_IQ2_S ? MMQ_MMA_TILE_X_K_Q3_K :
 type == GGML_TYPE_IQ3_XXS ? MMQ_MMA_TILE_X_K_Q8_0 :
 type == GGML_TYPE_IQ3_S ? MMQ_MMA_TILE_X_K_Q8_0 :
-//type == GGML_TYPE_IQ1_S ? MMQ_MMA_TILE_X_K_Q8_0 :
+type == GGML_TYPE_IQ1_S ? MMQ_MMA_TILE_X_K_Q8_0 :
 type == GGML_TYPE_IQ4_XS ? MMQ_MMA_TILE_X_K_Q8_0 :
 type == GGML_TYPE_IQ4_NL ? MMQ_MMA_TILE_X_K_Q8_0 :
 0;

@@ -3058,7 +3058,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
 template void mul_mat_q_case<type>(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) \
 
 extern DECL_MMQ_CASE(GGML_TYPE_Q4_0);
-//extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);
+extern DECL_MMQ_CASE(GGML_TYPE_Q4_1);
 extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
 extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
 extern DECL_MMQ_CASE(GGML_TYPE_Q6_0);

@@ -3074,7 +3074,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_IQ2_XS);
 extern DECL_MMQ_CASE(GGML_TYPE_IQ2_S);
 extern DECL_MMQ_CASE(GGML_TYPE_IQ3_XXS);
 extern DECL_MMQ_CASE(GGML_TYPE_IQ3_S);
-//extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S);
+extern DECL_MMQ_CASE(GGML_TYPE_IQ1_S);
 extern DECL_MMQ_CASE(GGML_TYPE_IQ4_NL);
 extern DECL_MMQ_CASE(GGML_TYPE_IQ4_XS);
 
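In mmq.cuh, the two constexpr helpers map each ggml_type to a tile layout, and the extern DECL_MMQ_CASE declarations match the per-type instance files. Before this commit, the commented-out branches meant GGML_TYPE_Q4_1 and GGML_TYPE_IQ1_S fell through to the fallback at the end of each chain (tile_x_sizes{0, 0, 0} and 0 respectively). A small compile-time check of the restored mappings (illustrative only, not part of the commit; the constant names are taken from the diff):

// Illustrative static_asserts (not in the commit): with the restored branches,
// mmq_get_mma_tile_x_k() no longer returns the 0 fallback for these types.
static_assert(mmq_get_mma_tile_x_k(GGML_TYPE_Q4_1)  == MMQ_MMA_TILE_X_K_Q8_1, "Q4_1 maps to the Q8_1 MMA tile again");
static_assert(mmq_get_mma_tile_x_k(GGML_TYPE_IQ1_S) == MMQ_MMA_TILE_X_K_Q8_0, "IQ1_S maps to the Q8_0 MMA tile again");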

0 commit comments
