@@ -91,6 +91,7 @@ int wram2mram(__mram_ptr void *pmram,void *pwram,uint32_t size)
9191}
9292
9393
94+ // psumf is a file-scope pointer so that every thread can access it
9495static float * psumf = NULL ;
9596
9697void init (unsigned int tasklet_id ) {
@@ -99,13 +100,11 @@ void init(unsigned int tasklet_id) {
99100#endif
100101 if (tasklet_id == 0 ){ // Initialize once the cycle counter
101102 mem_reset (); // Reset the heap
102-
103+ // the first thread sets the fp32->fp16 table pointer
103104 ptable_f32_f16 = (__mram_ptr float * )DPU_MRAM_HEAP_POINTER ;
104105 }
105106 // Barrier
106107 barrier_wait (& my_barrier );
107-
108- // ptable_f32_f16 = (__mram_ptr float *)DPU_MRAM_HEAP_POINTER;
109108}
110109
111110// main
@@ -115,8 +114,7 @@ int main() {
115114
116115 init (tasklet_id );
117116
118- //fp32->fp16 table
119- ptable_f32_f16 = (__mram_ptr float * )DPU_MRAM_HEAP_POINTER ;
117+ // configure the fp32->fp16 table length and the starting offset
120118 uint32_t table_f32_f16_len = (1 << 16 )* sizeof (float );
121119 uint32_t offset = table_f32_f16_len ;
122120 int input_row_size = 0 ;
@@ -141,7 +139,7 @@ int main() {
141139 cache_meta -> layer_num ,cache_meta -> weight_type ,cache_meta -> rows_per_dpu ,cache_meta -> rest_rows ,cache_meta -> input_offset );
142140#endif
143141
144- // 先不考虑尾行
142+ // set the start row, end row and number of rows for each thread
145143 uint16_t weight_rows_per_thread = cache_meta -> rows_per_dpu / NR_TASKLETS ;
146144 uint16_t weight_start_row = tasklet_id * weight_rows_per_thread ;
147145 uint16_t weight_end_row = weight_start_row + weight_rows_per_thread ;
@@ -159,14 +157,17 @@ int main() {
159157
160158 //input metadata
161159 offset += (cache_meta -> layer_len * cache_meta -> layer_num );
160+
162161#if PRINT
163162 printf ("layer_len=%d, input metadata offset=%d\n" ,cache_meta -> layer_len ,offset );
164163#endif
164+
165165 uint32_t inputmetadatabase = weightmetadatabase + sizeof (struct pim_meta ) + cache_meta -> layer_len * cache_meta -> layer_num ;
166166 pim_matrix_des * pinputcache = (pim_matrix_des * ) mem_alloc (sizeof (pim_matrix_des ));
167167 mram_read ((__mram_ptr void const * ) (inputmetadatabase ), pinputcache , sizeof (pim_matrix_des ));
168168 input_cols = pinputcache -> ne [1 ];
169169 assert (input_cols == 1 && "Only support vector as input." );
170+
170171#if PRINT
171172 printf ("input_type=%d, layerID=%d\n" ,pinputcache -> type ,pinputcache -> layerid );
172173 for (int nn = 0 ;nn < GGML_MAX_DIMS ;nn ++ ) {
@@ -175,6 +176,7 @@ int main() {
175176#endif
176177
177178 assert (cache_meta -> weight_type == ((uint16_t )GGML_TYPE_Q4_0 ) && "Only support Q4_0 weight." );
179+
178180 //weight info: GGML_TYPE_Q4_0 default
179181 if (cache_meta -> weight_type == ((uint16_t )GGML_TYPE_Q4_0 )) {
180182 if (pinputcache -> type != GGML_TYPE_Q8_0 ) {
@@ -194,6 +196,7 @@ int main() {
194196
195197 // psumf = (float *)mem_alloc(sizeof(float)*input_cols*weight_rows_cur_thread);
196198 memset (psumf , 0 ,sizeof (float )* input_cols * weight_rows_cur_thread );
199+
197200#if PRINT
198201 printf ("input_cols=%d, rows_cur_thread=%d, nb=%d, input_row_size=%d\n" ,input_cols ,weight_rows_cur_thread ,nb ,input_row_size );
199202#endif
0 commit comments