1818#include < excpt.h>
1919#endif
2020
21- #include " ggml- kleidiai.h"
21+ #include " kleidiai.h"
2222
2323#include " ggml-cpu.h"
2424#include " ggml-impl.h"
2525#include " ggml-backend-impl.h"
2626#include " ggml-threading.h"
2727
28- #include " kleidiai_kernels .h"
28+ #include " kernels .h"
2929
3030#include " kai_common.h"
3131
32- static const size_t k_q4_0_block_size = 32 ;
32+ #define GGML_COMMON_DECL_CPP
33+ #include " ggml-common.h"
3334
3435struct ggml_kleidiai_context {
3536 ggml_kleidiai_kernels * kernels;
@@ -78,9 +79,8 @@ class tensor_traits : public ggml::cpu::tensor_traits {
7879 size_t mr = kernel->get_mr ();
7980 size_t kr = kernel->get_kr ();
8081 size_t sr = kernel->get_sr ();
81- size_t bl = k_q4_0_block_size;
8282
83- size = ctx.kernels ->lhs_info .packed_size (m, k, bl , mr, kr, sr);
83+ size = ctx.kernels ->lhs_info .packed_size (m, k, QK4_0 , mr, kr, sr);
8484
8585 return true ;
8686 }
@@ -121,7 +121,6 @@ class tensor_traits : public ggml::cpu::tensor_traits {
121121 size_t mr = kernel->get_mr ();
122122 size_t kr = kernel->get_kr ();
123123 size_t sr = kernel->get_sr ();
124- size_t bl = k_q4_0_block_size;
125124
126125 // Calculate number of columns to be processed per thread
127126 const size_t num_m_per_thread = kai_roundup (m, nth) / nth;
@@ -135,24 +134,24 @@ class tensor_traits : public ggml::cpu::tensor_traits {
135134 // Transform LHS
136135 const size_t src_stride = src1->nb [1 ];
137136 const float * src_ptr = reinterpret_cast <const float *>(lhs + lhs_info->get_offset (0 , dst->src [1 ]->nb [1 ]));
138- const size_t lhs_packed_offset = lhs_info->get_packed_offset (m_start, k, bl , mr, kr, sr);
137+ const size_t lhs_packed_offset = lhs_info->get_packed_offset (m_start, k, QK4_0 , mr, kr, sr);
139138 void * lhs_packed_ptr = static_cast <void *>(lhs_packed + lhs_packed_offset);
140139
141- lhs_info->pack_func (m_to_process, k, bl , mr, kr, sr, m_start, src_ptr, src_stride, lhs_packed_ptr);
140+ lhs_info->pack_func (m_to_process, k, QK4_0 , mr, kr, sr, m_start, src_ptr, src_stride, lhs_packed_ptr);
142141 }
143142
144143 ggml_barrier (params->threadpool );
145144
146145 // Perform the operation
147146 const size_t dst_stride = dst->nb [1 ];
148- const size_t lhs_packed_offset = lhs_info->get_packed_offset (0 , k, k_q4_0_block_size , mr, kr, sr);
149- const size_t rhs_packed_offset = kernel->get_rhs_packed_offset (n_start, k, k_q4_0_block_size );
147+ const size_t lhs_packed_offset = lhs_info->get_packed_offset (0 , k, QK4_0 , mr, kr, sr);
148+ const size_t rhs_packed_offset = kernel->get_rhs_packed_offset (n_start, k, QK4_0 );
150149 const size_t dst_offset = kernel->get_dst_offset (0 , n_start, dst_stride);
151150 const void * rhs_ptr = static_cast <const void *>(rhs_packed + rhs_packed_offset);
152151 const void * lhs_ptr = (const void *)((const char *)lhs_packed + lhs_packed_offset);
153152 float *dst_ptr = reinterpret_cast <float *>(static_cast <uint8_t *>(dst->data ) + dst_offset);
154153
155- kernel->run_kernel (m, n_to_process, k, k_q4_0_block_size , lhs_ptr, rhs_ptr, dst_ptr,
154+ kernel->run_kernel (m, n_to_process, k, QK4_0 , lhs_ptr, rhs_ptr, dst_ptr,
156155 dst_stride, sizeof (float ), -FLT_MAX, FLT_MAX);
157156 return true ;
158157 }
@@ -169,13 +168,13 @@ class tensor_traits : public ggml::cpu::tensor_traits {
169168 size_t sr = ctx.kernels ->gemm .get_sr ();
170169
171170#ifndef NDEBUG
172- const size_t repacked_size = ctx.kernels ->rhs_info .packed_size (n, k, nr, kr, k_q4_0_block_size );
171+ const size_t repacked_size = ctx.kernels ->rhs_info .packed_size (n, k, nr, kr, QK4_0 );
173172 GGML_ASSERT (repacked_size <= data_size && " repacked size larger than the packed size!" );
174173#endif
175174 struct kai_rhs_pack_qs4cxs1s0_param params;
176175 params.lhs_zero_point = 1 ;
177176 params.rhs_zero_point = 8 ;
178- ctx.kernels ->rhs_info .pack_func (1 , n, k, nr, kr, sr, k_q4_0_block_size , (const uint8_t *)data, NULL , tensor->data , 0 , ¶ms);
177+ ctx.kernels ->rhs_info .pack_func (1 , n, k, nr, kr, sr, QK4_0 , (const uint8_t *)data, NULL , tensor->data , 0 , ¶ms);
179178
180179 return 0 ;
181180
0 commit comments