 #include <alloc.h>
 #include <barrier.h>
 #include <seqread.h>
+#include <mutex_pool.h>
 
 #define PIM_KERNEL_DPU 1
 #include "../ggml/include/ggml.h"
 #define GGML_COMMON_DECL_C
 #include "../ggml/src/ggml-common.h"
 
 #define PRINT 0
+#define SEGMENT_PER_ROW 4
+
+// Find the lowest index for the rank-th group
+#define BLOCK_LOW(rank, size, n) ((rank) * (n) / (size))
+
+// Find the highest index for the rank-th group
+#define BLOCK_HIGH(rank, size, n) (BLOCK_LOW((rank) + 1, (size), (n)) - 1)
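+// Example: with size = 4 tasklets and n = 10 segments the ranges are
+// rank 0 -> [0,1], rank 1 -> [2,4], rank 2 -> [5,6], rank 3 -> [7,9].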
 
 __mram_ptr float *ptable_f32_f16;
 
+__host int16_t mul_table_int4_int8[1 << 4][1 << 8];
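+// 16 x 256 table holding every possible int4 * int8 product, indexed by the
+// offset values used below; being __host it is presumably filled by the host
+// before the kernel is launched.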
+
 inline static float lookup_fp16_to_fp32(uint16_t f) {
     uint16_t s;
     memcpy(&s, &f, sizeof(uint16_t));
@@ -35,6 +45,7 @@ inline static float lookup_fp16_to_fp32(uint16_t f) {
 
 // Barrier
 BARRIER_INIT(my_barrier, NR_TASKLETS);
+MUTEX_POOL_INIT(g_psumf_mutex_pool, NR_TASKLETS);
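+// Pool of mutexes guarding the shared g_psumf accumulator: once a row is split
+// into segments, several tasklets may update the same output entry, so each
+// += below is done under the mutex selected by psumf_idx.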
 
 /*
 DPU MRAM Memory:
@@ -91,8 +102,9 @@ int wram2mram(__mram_ptr void *pmram,void *pwram,uint32_t size)
 }
 
 
-// set psumf to global value for each thread access
-static float *psumf = NULL;
+// g_psumf and g_pinput_cache are global so that every tasklet can access them
+static float *g_psumf = NULL;
+static block_q8_0 *g_pinput_cache = NULL;
 
 void init(unsigned int tasklet_id) {
 #if PRINT
@@ -140,9 +152,11 @@ int main() {
 #endif
 
     // set start line, end line and line count for each thread
-    uint16_t weight_rows_per_thread = cache_meta->rows_per_dpu / NR_TASKLETS;
-    uint16_t weight_start_row = tasklet_id * weight_rows_per_thread;
-    uint16_t weight_end_row = weight_start_row + weight_rows_per_thread;
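+    // Each weight row is split into SEGMENT_PER_ROW segments and the segments,
+    // rather than whole rows, are spread across the tasklets for a finer-grained
+    // work split.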
+    uint16_t segments_num = cache_meta->rows_per_dpu * SEGMENT_PER_ROW;
+    uint16_t segment_start = BLOCK_LOW(tasklet_id, NR_TASKLETS, segments_num);
+    uint16_t segment_end = BLOCK_HIGH(tasklet_id, NR_TASKLETS, segments_num);
+
+    assert(segment_start <= segment_end && "There are not enough segments to allocate to the tasklets");
 
     // todo: if a remainder row exists, the first tasklet in each DPU can take one more row
     uint16_t weight_rows_cur_thread;
@@ -184,83 +198,82 @@ int main() {
             return -1;
         }
         int nb = pinputcache->ne[0] / QK8_0;
+
+        assert(SEGMENT_PER_ROW <= nb && nb % SEGMENT_PER_ROW == 0
+               && "SEGMENT_PER_ROW must not exceed nb and must divide it evenly");
+
         int qk = QK8_0;
         input_row_size = nb * sizeof(block_q8_0);
         __mram_ptr void *pweight_base = (__mram_ptr void *)(weightmetadatabase + sizeof(struct pim_meta));
         __mram_ptr void *pinput_base = DPU_MRAM_HEAP_POINTER + cache_meta->input_offset + sizeof(pim_matrix_des);
-
+
         if (tasklet_id == 0) {
-            psumf = (float *)mem_alloc(sizeof(float) * input_cols * weight_rows_cur_thread);
+            g_psumf = (float *)mem_alloc(sizeof(float) * input_cols * weight_rows_cur_thread);
+            g_pinput_cache = (block_q8_0 *)mem_alloc(sizeof(block_q8_0) * nb);
+            memset(g_psumf, 0, sizeof(float) * input_cols * weight_rows_cur_thread);
         }
-        barrier_wait(&my_barrier);
 
-        // psumf = (float *)mem_alloc(sizeof(float)*input_cols*weight_rows_cur_thread);
-        memset(psumf, 0, sizeof(float) * input_cols * weight_rows_cur_thread);
-
 #if PRINT
         printf("input_cols=%d, rows_cur_thread=%d, nb=%d, input_row_size=%d\n", input_cols, weight_rows_cur_thread, nb, input_row_size);
 #endif
-        block_q4_0 *pweight_cache = (block_q4_0 *)mem_alloc(sizeof(block_q4_0) * nb);
-        block_q8_0 *pinput_cache = (block_q8_0 *)mem_alloc(sizeof(block_q8_0) * nb);
+
+        uint16_t segment_nb_size = nb / SEGMENT_PER_ROW;
+        block_q4_0 *pweight_cache = (block_q4_0 *)mem_alloc(sizeof(block_q4_0) * segment_nb_size);
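+        // The weight cache now only holds one segment of a row instead of a full
+        // row, which shrinks the per-tasklet WRAM footprint by SEGMENT_PER_ROW.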
 
         // weight_rows_cur_thread = 16;
         for (int l = 0; l < input_cols; l++) {
-            __mram_ptr block_q8_0 *pinput = pinput_base + l * nb * sizeof(block_q8_0);
-            mram2wram(pinput, pinput_cache, sizeof(block_q8_0) * nb);
-#if PRINT
-            printf("input:\n");
-            for (int i = 0; i < nb; i++) {
-                printf("d=%u\n", pinput[i].d);
-                for (int kkk = 0; kkk < QK8_0; kkk++) {
-                    printf("%d ", pinput[i].qs[kkk]);
-                }
-                printf("\n");
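+            // Tasklet 0 loads the whole input column into the shared WRAM cache;
+            // the barrier below ensures it is filled before any tasklet reads it.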
+            if (tasklet_id == 0) {
+                __mram_ptr block_q8_0 *pinput = pinput_base + l * nb * sizeof(block_q8_0);
+                mram2wram(pinput, g_pinput_cache, sizeof(block_q8_0) * nb);
             }
-            printf("pweight_base: %p\n", pweight_base);
-#endif
-            // for(int k = 0;k < weight_rows_cur_thread;k++) {
-            for (int k = weight_start_row; k < weight_end_row; ++k) {
-                __mram_ptr block_q4_0 *pweight = pweight_base + pinputcache->layerid * cache_meta->layer_len + k * nb * sizeof(block_q4_0);
-                mram2wram(pweight, pweight_cache, sizeof(block_q4_0) * nb);
-#if PRINT
-                if (k % 64 == 0) {
-                    printf("pweight_cache[%d].d=%d\n pweight_cache[%d].qs=", k * 128, pweight_cache[0].d, k * 128);
-                    for (int kkk = 0; kkk < QK4_0 / 2; kkk++) {
-                        int v0 = (pweight_cache[0].qs[kkk] & 0x0f) - 8;
-                        int v1 = (pweight_cache[0].qs[kkk] >> 4) - 8;
-                        printf(" %d, %d", v0, v1);
-                    }
-                    printf("\n");
-                }
-#endif
 
-                for (int i = 0; i < nb; i++) {
-                    //printf("input_col:%d, current inner weight row idx:%d\n",l,k);
+            barrier_wait(&my_barrier);
+
+            __mram_ptr block_q4_0 *pweight_addr = pweight_base + pinputcache->layerid * cache_meta->layer_len;
 
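+            // Each k indexes one segment: k / SEGMENT_PER_ROW is the weight row,
+            // k % SEGMENT_PER_ROW is the position of the segment within that row.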
+            for (int k = segment_start; k <= segment_end; ++k) {
+                __mram_ptr block_q4_0 *pweight = pweight_addr + k * segment_nb_size;
+                mram2wram(pweight, pweight_cache, sizeof(block_q4_0) * segment_nb_size);
+
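+                // Point at the matching slice of the shared input row cache,
+                // i.e. g_pinput_cache + (k % SEGMENT_PER_ROW) * segment_nb_size.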
+                block_q8_0 *pinput_cache = g_pinput_cache + k % SEGMENT_PER_ROW * segment_nb_size;
+
+                for (int i = 0; i < segment_nb_size; i++) {
                     int sumi = 0;
                     for (int j = 0; j < qk / 2; ++j) {
-                        const int v0 = (pweight_cache[i].qs[j] & 0x0F) - 8;
-                        const int v1 = (pweight_cache[i].qs[j] >> 4) - 8;
+                        const int8_t v0 = (pweight_cache[i].qs[j] & 0x0F) - 8;
+                        const int8_t v1 = (pweight_cache[i].qs[j] >> 4) - 8;
 
-                        sumi += (v0 * pinput_cache[i].qs[j]) + (v1 * pinput_cache[i].qs[j + qk/2]);
+                        // sumi += (v0 * pinput_cache[i].qs[j]) + (v1 * pinput_cache[i].qs[j + qk/2]);
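+                        // Table lookup instead of a multiply: v0/v1 are shifted
+                        // from [-8,7] to [0,15] and the int8 activation from
+                        // [-128,127] to [0,255] to form non-negative indices.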
+                        sumi += mul_table_int4_int8[v0 + 8][pinput_cache[i].qs[j] - INT8_MIN] +
+                                mul_table_int4_int8[v1 + 8][pinput_cache[i].qs[j + qk/2] - INT8_MIN];
                     }
-
-                    psumf[l * weight_rows_cur_thread + k] += sumi * FP16_TO_FP32(pweight_cache[i].d) * FP16_TO_FP32(pinput_cache[i].d);
+
+                    int psumf_idx = l * weight_rows_cur_thread + k / SEGMENT_PER_ROW;
+                    float sum = sumi * FP16_TO_FP32(pweight_cache[i].d) * FP16_TO_FP32(pinput_cache[i].d);
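+                    // Different tasklets can own segments of the same row, so the
+                    // shared accumulator is updated under the mutex picked by psumf_idx.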
+                    mutex_pool_lock(&g_psumf_mutex_pool, psumf_idx);
+                    g_psumf[psumf_idx] += sum;
+                    // g_psumf[psumf_idx] += sumi;
+                    mutex_pool_unlock(&g_psumf_mutex_pool, psumf_idx);
                 }
             }
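+            // Make sure every tasklet is done reading g_pinput_cache before
+            // tasklet 0 refills it for the next input column (only relevant
+            // when input_cols > 1).
+            barrier_wait(&my_barrier);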
         }
     }
 
-    offset += (sizeof(pim_matrix_des) + input_row_size * input_cols);
-#if PRINT
-    for (int iii = 0; iii < cache_meta->rows_per_dpu; iii += 128) {
-        printf("psumf[%d]=%f\n", iii, psumf[iii]);
+    barrier_wait(&my_barrier);
+
+    if (tasklet_id == 0) {
+        offset += (sizeof(pim_matrix_des) + input_row_size * input_cols);
+#if PRINT
+        for (int iii = 0; iii < cache_meta->rows_per_dpu; iii += 128) {
+            printf("g_psumf[%d]=%f\n", iii, g_psumf[iii]);
+        }
+
+        printf("output offset=%d\n", offset);
+#endif
+        // Write C Matrix to current MRAM block
+        // Note: with input_cols > 1, the results should be rearranged on host
+        wram2mram((__mram_ptr void *)(DPU_MRAM_HEAP_POINTER + offset), g_psumf, sizeof(float) * input_cols * weight_rows_cur_thread);
     }
 
-    printf("output offset=%d\n", offset);
-#endif
-    // Write C Matrix to current MRAM block
-    // Note: with input_cols > 1, the results should be rearranged on host
-    wram2mram((__mram_ptr void *)(DPU_MRAM_HEAP_POINTER + offset), psumf, sizeof(float) * input_cols * weight_rows_cur_thread);
     return 0;
 }