1111#include <alloc.h>
1212#include <barrier.h>
1313#include <seqread.h>
14+ #include <mutex_pool.h>
1415
1516#define PIM_KERNEL_DPU 1
1617#include "../ggml/include/ggml.h"
1718#define GGML_COMMON_DECL_C
1819#include "../ggml/src/ggml-common.h"
1920
2021#define PRINT 0
22+ #define SEGMENT_PER_ROW 4
23+
24+ // Find the lowest index for the rank-th group
25+ #define BLOCK_LOW (rank , size , n ) ((rank) * (n) / (size))
26+
27+ // Find the highest index for the rank-th group
28+ #define BLOCK_HIGH (rank , size , n ) (BLOCK_LOW((rank) + 1, (size), (n)) - 1)
2129
2230__mram_ptr float * ptable_f32_f16 ;
2331
@@ -35,6 +43,7 @@ inline static float lookup_fp16_to_fp32(uint16_t f) {
3543
3644// Barrier
3745BARRIER_INIT (my_barrier , NR_TASKLETS );
46+ MUTEX_POOL_INIT (g_psumf_mutex_pool , NR_TASKLETS );
3847
3948/*
4049DPU MRAM Memory:
@@ -91,8 +100,9 @@ int wram2mram(__mram_ptr void *pmram,void *pwram,uint32_t size)
91100}
92101
93102
94- // set psumf to global value for each thread access
95- static float * psumf = NULL ;
103+ // set g_psumf to global value for each thread access
104+ static float * g_psumf = NULL ;
105+ static block_q8_0 * g_pinput_cache = NULL ;
96106
97107void init (unsigned int tasklet_id ) {
98108#if PRINT
@@ -140,9 +150,11 @@ int main() {
140150#endif
141151
142152 // set sart line, end line and line number in each thread
143- uint16_t weight_rows_per_thread = cache_meta -> rows_per_dpu / NR_TASKLETS ;
144- uint16_t weight_start_row = tasklet_id * weight_rows_per_thread ;
145- uint16_t weight_end_row = weight_start_row + weight_rows_per_thread ;
153+ uint16_t segments_num = cache_meta -> rows_per_dpu * SEGMENT_PER_ROW ;
154+ uint16_t segment_start = BLOCK_LOW (tasklet_id , NR_TASKLETS , segments_num );
155+ uint16_t segment_end = BLOCK_HIGH (tasklet_id , NR_TASKLETS , segments_num );
156+
157+ assert (segment_start <= segment_end && "There are not enough segments to allocate to the tasklets" );
146158
147159 // todo:rest row is existed, first thread in every dpu can one more row
148160 uint16_t weight_rows_cur_thread ;
@@ -184,83 +196,80 @@ int main() {
184196 return -1 ;
185197 }
186198 int nb = pinputcache -> ne [0 ]/QK8_0 ;
199+
200+ assert (SEGMENT_PER_ROW <= nb && nb % SEGMENT_PER_ROW == 0
201+ && "Too many segments are allocated to each row." );
202+
187203 int qk = QK8_0 ;
188204 input_row_size = nb * sizeof (block_q8_0 );
189205 __mram_ptr void * pweight_base = (__mram_ptr void * )(weightmetadatabase + sizeof (struct pim_meta ));
190206 __mram_ptr void * pinput_base = DPU_MRAM_HEAP_POINTER + cache_meta -> input_offset + sizeof (pim_matrix_des );
191-
207+
192208 if (tasklet_id == 0 ) {
193- psumf = (float * )mem_alloc (sizeof (float )* input_cols * weight_rows_cur_thread );
209+ g_psumf = (float * )mem_alloc (sizeof (float )* input_cols * weight_rows_cur_thread );
210+ g_pinput_cache = (block_q8_0 * ) mem_alloc (sizeof (block_q8_0 ) * nb );
211+ memset (g_psumf , 0 ,sizeof (float )* input_cols * weight_rows_cur_thread );
194212 }
195- barrier_wait (& my_barrier );
196213
197- // psumf = (float *)mem_alloc(sizeof(float)*input_cols*weight_rows_cur_thread);
198- memset (psumf , 0 ,sizeof (float )* input_cols * weight_rows_cur_thread );
199-
200214#if PRINT
201215 printf ("input_cols=%d, rows_cur_thread=%d, nb=%d, input_row_size=%d\n" ,input_cols ,weight_rows_cur_thread ,nb ,input_row_size );
202216#endif
203- block_q4_0 * pweight_cache = (block_q4_0 * ) mem_alloc (sizeof (block_q4_0 )* nb );
204- block_q8_0 * pinput_cache = (block_q8_0 * ) mem_alloc (sizeof (block_q8_0 )* nb );
217+
218+ uint16_t segment_nb_size = nb / SEGMENT_PER_ROW ;
219+ block_q4_0 * pweight_cache = (block_q4_0 * ) mem_alloc (sizeof (block_q4_0 ) * segment_nb_size );
205220
206221 // weight_rows_cur_thread = 16;
207222 for (int l = 0 ;l < input_cols ;l ++ ) {
208- __mram_ptr block_q8_0 * pinput = pinput_base + l * nb * sizeof (block_q8_0 );
209- mram2wram (pinput , pinput_cache , sizeof (block_q8_0 )* nb );
210- #if PRINT
211- printf ("input:\n" );
212- for (int i = 0 ; i < nb ; i ++ ) {
213- printf ("d=%u\n" ,pinput [i ].d );
214- for (int kkk = 0 ;kkk < QK8_0 ;kkk ++ ) {
215- printf ("%d " ,pinput [i ].qs [kkk ]);
216- }
217- printf ("\n" );
223+ if (tasklet_id == 0 ) {
224+ __mram_ptr block_q8_0 * pinput = pinput_base + l * nb * sizeof (block_q8_0 );
225+ mram2wram (pinput , g_pinput_cache , sizeof (block_q8_0 )* nb );
218226 }
219- printf ("pweight_base: %p\n" , pweight_base );
220- #endif
221- // for(int k = 0;k < weight_rows_cur_thread;k++) {
222- for (int k = weight_start_row ; k < weight_end_row ; ++ k ) {
223- __mram_ptr block_q4_0 * pweight = pweight_base + pinputcache -> layerid * cache_meta -> layer_len + k * nb * sizeof (block_q4_0 );
224- mram2wram (pweight , pweight_cache , sizeof (block_q4_0 )* nb );
225- #if PRINT
226- if (k % 64 == 0 ) {
227- printf ("pweight_cache[%d].d=%d\n pweight_cache[%d].qs=" , k * 128 , pweight_cache [0 ].d , k * 128 );
228- for (int kkk = 0 ;kkk < QK4_0 /2 ;kkk ++ ) {
229- int v0 = (pweight_cache [0 ].qs [kkk ] & 0x0f ) - 8 ;
230- int v1 = (pweight_cache [0 ].qs [kkk ] >> 4 ) - 8 ;
231- printf (" %d, %d" , v0 , v1 );
232- }
233- printf ("\n" );
234- }
235- #endif
236227
237- for (int i = 0 ; i < nb ; i ++ ) {
238- //printf("input_col:%d, current inner weight row idx:%d\n",l,k);
228+ barrier_wait (& my_barrier );
229+
230+ __mram_ptr block_q4_0 * pweight_addr = pweight_base + pinputcache -> layerid * cache_meta -> layer_len ;
239231
232+ for (int k = segment_start ; k <= segment_end ; ++ k ) {
233+ __mram_ptr block_q4_0 * pweight = pweight_addr + k * segment_nb_size ;
234+ mram2wram (pweight , pweight_cache , sizeof (block_q4_0 ) * segment_nb_size );
235+
236+ block_q8_0 * pinput_cache = g_pinput_cache + k % SEGMENT_PER_ROW * segment_nb_size ;
237+
238+ for (int i = 0 ; i < segment_nb_size ; i ++ ) {
240239 int sumi = 0 ;
241240 for (int j = 0 ; j < qk /2 ; ++ j ) {
242241 const int v0 = (pweight_cache [i ].qs [j ] & 0x0F ) - 8 ;
243242 const int v1 = (pweight_cache [i ].qs [j ] >> 4 ) - 8 ;
244243
245244 sumi += (v0 * pinput_cache [i ].qs [j ]) + (v1 * pinput_cache [i ].qs [j + qk /2 ]);
246245 }
247-
248- psumf [l * weight_rows_cur_thread + k ] += sumi * FP16_TO_FP32 (pweight_cache [i ].d )* FP16_TO_FP32 (pinput_cache [i ].d );
246+
247+ int psumf_idx = l * weight_rows_cur_thread + k / SEGMENT_PER_ROW ;
248+ float sum = sumi * FP16_TO_FP32 (pweight_cache [i ].d ) * FP16_TO_FP32 (pinput_cache [i ].d );
249+ mutex_pool_lock (& g_psumf_mutex_pool , psumf_idx );
250+ g_psumf [psumf_idx ] += sum ;
251+ // g_psumf[psumf_idx] += sumi;
252+ mutex_pool_unlock (& g_psumf_mutex_pool , psumf_idx );
249253 }
250254 }
251255 }
252256 }
253257
254- offset += (sizeof (pim_matrix_des ) + input_row_size * input_cols );
255- #if PRINT
256- for (int iii = 0 ;iii < cache_meta -> rows_per_dpu ;iii += 128 ) {
257- printf ("psumf[%d]=%f\n" ,iii ,psumf [iii ]);
258+ barrier_wait (& my_barrier );
259+
260+ if (tasklet_id == 0 ){
261+ offset += (sizeof (pim_matrix_des ) + input_row_size * input_cols );
262+ #if PRINT
263+ for (int iii = 0 ;iii < cache_meta -> rows_per_dpu ;iii += 128 ) {
264+ printf ("g_psumf[%d]=%f\n" ,iii ,g_psumf [iii ]);
265+ }
266+
267+ printf ("output offset=%d\n" ,offset );
268+ #endif
269+ // Write C Matrix to current MRAM block
270+ // Note: with input_cols > 1, the results should be rearranged on host
271+ wram2mram ((__mram_ptr void * ) (DPU_MRAM_HEAP_POINTER + offset ), g_psumf , sizeof (float )* input_cols * weight_rows_cur_thread );
258272 }
259273
260- printf ("output offset=%d\n" ,offset );
261- #endif
262- // Write C Matrix to current MRAM block
263- // Note: with input_cols > 1, the results should be rearranged on host
264- wram2mram ((__mram_ptr void * ) (DPU_MRAM_HEAP_POINTER + offset ), psumf , sizeof (float )* input_cols * weight_rows_cur_thread );
265274 return 0 ;
266275}
0 commit comments