@@ -107,8 +107,9 @@ int main() {
107107 ptable_f32_f16 = (__mram_ptr float * )DPU_MRAM_HEAP_POINTER ;
108108 uint32_t table_f32_f16_len = (1 << 16 )* sizeof (float );
109109 uint32_t offset = table_f32_f16_len ;
110- int input_row_size ,input_cols ;
111- float * psumf ;
110+ int input_row_size = 0 ;
111+ int input_cols = 0 ;
112+ float * psumf = NULL ;
112113
113114#if PRINT
114115 printf ("table_f32_f16_len=%d\n" ,table_f32_f16_len );
@@ -124,7 +125,7 @@ int main() {
124125 mram_read ((__mram_ptr void const * ) (weightmetadatabase ), cache_meta , sizeof (struct pim_meta ));
125126
126127#if PRINT
127- printf ("layer_num: %d, weight_type=%d,rows_per_dpu=%d,rest_rows=%d,input_offset=%d" ,
128+ printf ("layer_num: %d, weight_type=%d, rows_per_dpu=%d, rest_rows=%d, input_offset=%d" ,
128129 cache_meta -> layer_num ,cache_meta -> weight_type ,cache_meta -> rows_per_dpu ,cache_meta -> rest_rows ,cache_meta -> input_offset );
129130#endif
130131
@@ -142,14 +143,15 @@ int main() {
142143 //input metadata
143144 offset += (cache_meta -> layer_len * cache_meta -> layer_num );
144145#if PRINT
145- printf ("layer_len=%d,offset=%d\n" ,cache_meta -> layer_len ,offset );
146+ printf ("layer_len=%d, input metadata offset=%d\n" ,cache_meta -> layer_len ,offset );
146147#endif
147148 uint32_t inputmetadatabase = weightmetadatabase + sizeof (struct pim_meta ) + cache_meta -> layer_len * cache_meta -> layer_num ;
148149 pim_matrix_des * pinputcache = (pim_matrix_des * ) mem_alloc (sizeof (pim_matrix_des ));
149150 mram_read ((__mram_ptr void const * ) (inputmetadatabase ), pinputcache , sizeof (pim_matrix_des ));
150151 input_cols = pinputcache -> ne [1 ];
152+ assert (input_cols == 1 && "Only support vector as input." );
151153#if PRINT
152- printf ("input_type=%d,layerID=%d\n" ,pinputcache -> type ,pinputcache -> layerid );
154+ printf ("input_type=%d, layerID=%d\n" ,pinputcache -> type ,pinputcache -> layerid );
153155 for (int nn = 0 ;nn < GGML_MAX_DIMS ;nn ++ ) {
154156 printf ("ne[%d]=%lld\n" ,nn ,pinputcache -> ne [nn ]);
155157 }
@@ -165,19 +167,19 @@ int main() {
165167 int nb = pinputcache -> ne [0 ]/QK8_0 ;
166168 int qk = QK8_0 ;
167169 input_row_size = nb * sizeof (block_q8_0 );
168- __mram_ptr block_q4_0 * pweight_base = (__mram_ptr block_q4_0 * )(weightmetadatabase + sizeof (struct pim_meta ));
169- __mram_ptr block_q8_0 * pinput_base = ( __mram_ptr block_q8_0 * )( DPU_MRAM_HEAP_POINTER + cache_meta -> input_offset + sizeof (pim_matrix_des ) );
170+ __mram_ptr void * pweight_base = (__mram_ptr void * )(weightmetadatabase + sizeof (struct pim_meta ));
171+ __mram_ptr void * pinput_base = DPU_MRAM_HEAP_POINTER + cache_meta -> input_offset + sizeof (pim_matrix_des );
170172 psumf = (float * )mem_alloc (sizeof (float )* input_cols * weight_rows_cur_thread );
171173 memset (psumf , 0 ,sizeof (float )* input_cols * weight_rows_cur_thread );
172174#if PRINT
173- printf ("input_cols=%d,rows_cur_thread=%d,nb=%d,input_row_size=%d\n" ,input_cols ,weight_rows_cur_thread ,nb ,input_row_size );
175+ printf ("input_cols=%d, rows_cur_thread=%d, nb=%d, input_row_size=%d\n" ,input_cols ,weight_rows_cur_thread ,nb ,input_row_size );
174176#endif
175177 block_q4_0 * pweight_cache = (block_q4_0 * ) mem_alloc (sizeof (block_q4_0 )* nb );
176178 block_q8_0 * pinput_cache = (block_q8_0 * ) mem_alloc (sizeof (block_q8_0 )* nb );
177179
178180 // weight_rows_cur_thread = 16;
179181 for (int l = 0 ;l < input_cols ;l ++ ) {
180- __mram_ptr block_q8_0 * pinput = pinput_base + l * nb ;
182+ __mram_ptr block_q8_0 * pinput = pinput_base + l * nb * sizeof ( block_q8_0 ) ;
181183 mram2wram (pinput , pinput_cache , sizeof (block_q8_0 )* nb );
182184#if PRINT
183185 printf ("input:\n" );
@@ -191,8 +193,7 @@ int main() {
191193 printf ("pweight_base: %p\n" , pweight_base );
192194#endif
193195 for (int k = 0 ;k < weight_rows_cur_thread ;k ++ ) {
194- //block_q4_0 *pqlayer0weight = (block_q4_0 *)(weightmetadatabase + sizeof(struct pim_meta) + cache_meta->layer_len*k);
195- __mram_ptr block_q4_0 * pweight = pweight_base + pinputcache -> layerid * cache_meta -> layer_len + k * nb ;
196+ __mram_ptr block_q4_0 * pweight = pweight_base + pinputcache -> layerid * cache_meta -> layer_len + k * nb * sizeof (block_q4_0 );
196197 mram2wram (pweight , pweight_cache , sizeof (block_q4_0 )* nb );
197198#if PRINT
198199 if (k % 64 == 0 ) {
@@ -207,11 +208,10 @@ int main() {
207208#endif
208209
209210 for (int i = 0 ; i < nb ; i ++ ) {
210- //printf("input_col:%d,weight_row :%d\n",l,k);
211+ //printf("input_col:%d, current inner weight row idx :%d\n",l,k);
211212
212213 int sumi = 0 ;
213214 for (int j = 0 ; j < qk /2 ; ++ j ) {
214- //printf("nb:%d,qk=%d,qs=%d\n",i,j,pweight_cache[i].qs[j]);
215215 const int v0 = (pweight_cache [i ].qs [j ] & 0x0F ) - 8 ;
216216 const int v1 = (pweight_cache [i ].qs [j ] >> 4 ) - 8 ;
217217
@@ -230,9 +230,10 @@ int main() {
230230 printf ("psumf[%d]=%f\n" ,iii ,psumf [iii ]);
231231 }
232232
233- printf ("offset=%d\n" ,offset );
233+ printf ("output offset=%d\n" ,offset );
234234#endif
235235 // Write C Matrix to current MRAM block
236- wram2mram ((__mram_ptr void * ) (DPU_MRAM_HEAP_POINTER + offset ),psumf ,sizeof (float )* input_cols * weight_rows_cur_thread );
236+ // Note: with input_cols > 1, the results should be rearranged on host
237+ wram2mram ((__mram_ptr void * ) (DPU_MRAM_HEAP_POINTER + offset ), psumf , sizeof (float )* input_cols * weight_rows_cur_thread );
237238 return 0 ;
238239}
0 commit comments