Skip to content

Commit f4eb1b3

Browse files
committed
Add support for multithreaded LHS conversion
1 parent 119d3bf commit f4eb1b3

File tree

2 files changed

+26
-19
lines changed

2 files changed

+26
-19
lines changed

ggml/src/ggml-cpu/CMakeLists.txt

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
117117
)
118118
if (GGML_MACHINE_SUPPORTS_${tag})
119119
set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+${tag}" PARENT_SCOPE)
120-
else()
120+
elseif(NOT tag STREQUAL "sme")
121121
set(ARM_MCPU_FLAG_FIX "${ARM_MCPU_FLAG_FIX}+no${tag}" PARENT_SCOPE)
122122
endif()
123123
set(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
@@ -325,9 +325,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
325325

326326
# Fetch KleidiAI sources:
327327
include(FetchContent)
328-
set(KLEIDIAI_COMMIT_TAG "v1.2.0")
328+
set(KLEIDIAI_COMMIT_TAG "v1.3.0")
329329
set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
330-
set(KLEIDIAI_ARCHIVE_MD5 "6634fefce7357ecfee9eace2068bc68b")
330+
set(KLEIDIAI_ARCHIVE_MD5 "060bd2dc64642b091f461cc8dd7426d9")
331331

332332
if (POLICY CMP0135)
333333
cmake_policy(SET CMP0135 NEW)
@@ -370,9 +370,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
370370
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/
371371
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/)
372372

373-
string(FIND ${ARCH_FLAGS} "+dotprod" DOTPROD_ENABLED)
374-
string(FIND ${ARCH_FLAGS} "+i8mm" I8MM_ENABLED)
375-
string(FIND ${ARCH_FLAGS} "+sme" SME_ENABLED)
373+
string(FIND "${ARCH_FLAGS}" "+dotprod" DOTPROD_ENABLED)
374+
string(FIND "${ARCH_FLAGS}" "+i8mm" I8MM_ENABLED)
375+
string(FIND "${ARCH_FLAGS}" "+sme" SME_ENABLED)
376376

377377
set(PRIVATE_ARCH_FLAGS ${ARCH_FLAGS})
378378

ggml/src/ggml-cpu/ggml-kleidiai/ggml-kleidiai.cpp

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -114,30 +114,37 @@ class tensor_traits : public ggml::cpu::tensor_traits {
114114
size_t sr = kernel->get_sr();
115115
size_t bl = k_q4_0_block_size;
116116

117-
const size_t lhs_packed_offset = lhs_info->get_packed_offset(0, k, bl, mr, kr, sr);
117+
// Calculate number of columns to be processed per thread
118+
const size_t num_m_per_thread = kai_roundup(m, nth) / nth;
119+
const size_t m_start = ith * num_m_per_thread;
120+
size_t m_to_process = num_m_per_thread;
121+
if ((m_start + m_to_process) > m) {
122+
m_to_process = m - m_start;
123+
}
118124

119-
if (ith == 0) {
125+
if(m_start < m) {
120126
// Transform LHS
121-
const size_t src_stride = src1->nb[1];
122-
const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(0, dst->src[1]->nb[1]));
123-
void * dst_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
127+
const size_t src_stride = src1->nb[1];
128+
const float * src_ptr = reinterpret_cast<const float *>(lhs + lhs_info->get_offset(0, dst->src[1]->nb[1]));
129+
const size_t lhs_packed_offset = lhs_info->get_packed_offset(m_start, k, bl, mr, kr, sr);
130+
void * lhs_packed_ptr = static_cast<void *>(lhs_packed + lhs_packed_offset);
124131

125-
lhs_info->pack_func(m, k, bl, mr, kr, sr, 0, src_ptr, src_stride, dst_ptr);
132+
lhs_info->pack_func(m_to_process, k, bl, mr, kr, sr, m_start, src_ptr, src_stride, lhs_packed_ptr);
126133
}
127134

128135
ggml_barrier(params->threadpool);
129-
// Perform the operation
130-
const size_t dst_stride = dst->nb[1];
131136

137+
// Perform the operation
138+
const size_t dst_stride = dst->nb[1];
139+
const size_t lhs_packed_offset = lhs_info->get_packed_offset(0, k, k_q4_0_block_size, mr, kr, sr);
132140
const size_t rhs_packed_offset = kernel->get_rhs_packed_offset(n_start, k, k_q4_0_block_size);
133141
const size_t dst_offset = kernel->get_dst_offset(0, n_start, dst_stride);
134-
135-
const void * lhs_ptr = static_cast<const void *>(lhs_packed + lhs_packed_offset);
136-
const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
137-
float *dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
142+
const void * rhs_ptr = static_cast<const void *>(rhs_packed + rhs_packed_offset);
143+
const void* lhs_ptr = (const void*)((const char *)lhs_packed + lhs_packed_offset);
144+
float *dst_ptr = reinterpret_cast<float *>(static_cast<uint8_t *>(dst->data) + dst_offset);
138145

139146
kernel->run_kernel(m, n_to_process, k, k_q4_0_block_size, lhs_ptr, rhs_ptr, dst_ptr,
140-
dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
147+
dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
141148
return true;
142149
}
143150
return false;

0 commit comments

Comments (0)