@@ -114,30 +114,37 @@ class tensor_traits : public ggml::cpu::tensor_traits {
114114 size_t sr = kernel->get_sr ();
115115 size_t bl = k_q4_0_block_size;
116116
117- const size_t lhs_packed_offset = lhs_info->get_packed_offset (0 , k, bl, mr, kr, sr);
117+ // Calculate the number of LHS rows (m) to be processed per thread
118+ const size_t num_m_per_thread = kai_roundup (m, nth) / nth;
119+ const size_t m_start = ith * num_m_per_thread;
120+ size_t m_to_process = num_m_per_thread;
121+ if ((m_start + m_to_process) > m) {
122+ m_to_process = m - m_start;
123+ }
118124
119- if (ith == 0 ) {
125+ if (m_start < m ) {
120126 // Transform LHS
121- const size_t src_stride = src1->nb [1 ];
122- const float * src_ptr = reinterpret_cast <const float *>(lhs + lhs_info->get_offset (0 , dst->src [1 ]->nb [1 ]));
123- void * dst_ptr = static_cast <void *>(lhs_packed + lhs_packed_offset);
127+ const size_t src_stride = src1->nb [1 ];
128+ const float * src_ptr = reinterpret_cast <const float *>(lhs + lhs_info->get_offset (0 , dst->src [1 ]->nb [1 ]));
129+ const size_t lhs_packed_offset = lhs_info->get_packed_offset (m_start, k, bl, mr, kr, sr);
130+ void * lhs_packed_ptr = static_cast <void *>(lhs_packed + lhs_packed_offset);
124131
125- lhs_info->pack_func (m , k, bl, mr, kr, sr, 0 , src_ptr, src_stride, dst_ptr );
132+ lhs_info->pack_func (m_to_process , k, bl, mr, kr, sr, m_start , src_ptr, src_stride, lhs_packed_ptr );
126133 }
127134
128135 ggml_barrier (params->threadpool );
129- // Perform the operation
130- const size_t dst_stride = dst->nb [1 ];
131136
137+ // Perform the operation
138+ const size_t dst_stride = dst->nb [1 ];
139+ const size_t lhs_packed_offset = lhs_info->get_packed_offset (0 , k, k_q4_0_block_size, mr, kr, sr);
132140 const size_t rhs_packed_offset = kernel->get_rhs_packed_offset (n_start, k, k_q4_0_block_size);
133141 const size_t dst_offset = kernel->get_dst_offset (0 , n_start, dst_stride);
134-
135- const void * lhs_ptr = static_cast <const void *>(lhs_packed + lhs_packed_offset);
136- const void * rhs_ptr = static_cast <const void *>(rhs_packed + rhs_packed_offset);
137- float *dst_ptr = reinterpret_cast <float *>(static_cast <uint8_t *>(dst->data ) + dst_offset);
142+ const void * rhs_ptr = static_cast <const void *>(rhs_packed + rhs_packed_offset);
143+ const void * lhs_ptr = (const void *)((const char *)lhs_packed + lhs_packed_offset);
144+ float *dst_ptr = reinterpret_cast <float *>(static_cast <uint8_t *>(dst->data ) + dst_offset);
138145
139146 kernel->run_kernel (m, n_to_process, k, k_q4_0_block_size, lhs_ptr, rhs_ptr, dst_ptr,
140- dst_stride, sizeof (float ), -FLT_MAX, FLT_MAX);
147+ dst_stride, sizeof (float ), -FLT_MAX, FLT_MAX);
141148 return true ;
142149 }
143150 return false ;
0 commit comments