@@ -36,8 +36,8 @@ struct test_model {
3636
3737void load_model (test_model & model, bool use_gpu = false ) {
3838 // create data
39- int KW = 3 , KH = 3 , IC = 10 , OC = 10 ;
40- int IW = 8 , IH = 6 , N = 1 ;
39+ int KW = 3 , KH = 3 , IC = 32 , OC = 64 ;
40+ int IW = 28 , IH = 40 , N = 1 ;
4141
4242 // Initialize adata
4343 std::vector<float > adata (KW * KH * IC * OC);
@@ -157,16 +157,21 @@ struct ggml_cgraph * build_graph(const test_model& model) {
157157 int d0 = 1 ;
158158 int d1 = 1 ;
159159
160- // split conv2d in fundamental methods for test unit
161- struct ggml_tensor * im2col_0 = ggml_im2col (ctx0, model.a , model.b , s0, s1, p0, p1, d0, d1, true , GGML_TYPE_F16);
162- ggml_set_name (im2col_0, " im2col_res" );
163- ggml_build_forward_expand (gf, im2col_0);
160+
164161
165162 // recalculate for avoid fragmentation
166163 struct ggml_tensor * conv2d_res = ggml_conv_2d (ctx0, model.a , model.b , s0, s1, p0, p1, d0, d1);
167164 ggml_set_name (conv2d_res, " conv2d_res" );
168165 ggml_build_forward_expand (gf, conv2d_res);
166+ int64_t *ne = conv2d_res->ne ;
167+ printf (" conv2d: (%zu, %zu, %zu, %zu) \n " , ne[0 ], ne[1 ], ne[2 ], ne[3 ]);
169168
169+
170+ struct ggml_tensor * wino_res = ggml_conv_2d_3x3 (ctx0, model.a , model.b );
171+ ggml_set_name (wino_res, " wino_res" );
172+ ggml_build_forward_expand (gf, wino_res);
173+ ne = wino_res->ne ;
174+ printf (" wino: (%zu, %zu, %zu, %zu) \n " , ne[0 ], ne[1 ], ne[2 ], ne[3 ]);
170175 ggml_free (ctx0);
171176 return gf;
172177}
@@ -218,173 +223,39 @@ int main(void)
218223
219224 struct ggml_cgraph * gf_res = compute_graph (model, allocr);
220225
221- struct ggml_tensor * im2col_res = NULL ;
226+ struct ggml_tensor * wino_res = NULL ;
222227 struct ggml_tensor * conv2d_res = NULL ;
223228
224229 for (int i = 0 ; i < ggml_graph_n_nodes (gf_res); ++i) {
225- if (strcmp (ggml_get_name (ggml_graph_node (gf_res, i)), " im2col_res " ) == 0 ) {
226- im2col_res = ggml_graph_node (gf_res, i);
230+ if (strcmp (ggml_get_name (ggml_graph_node (gf_res, i)), " wino_res " ) == 0 ) {
231+ wino_res = ggml_graph_node (gf_res, i);
227232 } else if (strcmp (ggml_get_name (ggml_graph_node (gf_res, i)), " conv2d_res" ) == 0 ) {
228233 conv2d_res = ggml_graph_node (gf_res, i);
229234 }
230235 }
231236
232- std::vector<uint16_t > im2col_data (ggml_nelements (im2col_res ));
237+ std::vector<float > wino_data (ggml_nelements (wino_res ));
233238 std::vector<float > conv2d_data (ggml_nelements (conv2d_res));
234239
235- ggml_backend_tensor_get (im2col_res, im2col_data .data (), 0 , ggml_nbytes (im2col_res ));
240+ ggml_backend_tensor_get (wino_res, wino_data .data (), 0 , ggml_nbytes (wino_res ));
236241 ggml_backend_tensor_get (conv2d_res, conv2d_data.data (), 0 , ggml_nbytes (conv2d_res));
237242
238- const int n_conv2d_test = 480 ;
239- const int n_im2col_test = 4320 ;
240-
241- float expected_conv2d [n_conv2d_test] = {
242- 150 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 150 .00f ,
243- 225 .00f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 225 .00f ,
244- 225 .00f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 225 .00f ,
245- 225 .00f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 225 .00f ,
246- 225 .00f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 225 .00f ,
247- 150 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 150 .00f ,
248- 150 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 150 .00f ,
249- 225 .00f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 225 .00f ,
250- 225 .00f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 225 .00f ,
251- 225 .00f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 225 .00f ,
252- 225 .00f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 225 .00f ,
253- 150 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 150 .00f ,
254- 150 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 150 .00f ,
255- 225 .00f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 225 .00f ,
256- 225 .00f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 225 .00f ,
257- 225 .00f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 225 .00f ,
258- 225 .00f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 225 .00f ,
259- 150 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 150 .00f ,
260- 150 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 150 .00f ,
261- 225 .00f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 225 .00f ,
262- 225 .00f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 225 .00f ,
263- 225 .00f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 225 .00f ,
264- 225 .00f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 225 .00f ,
265- 150 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 150 .00f ,
266- 150 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 150 .00f ,
267- 225 .00f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 225 .00f ,
268- 225 .00f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 225 .00f ,
269- 225 .00f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 225 .00f ,
270- 225 .00f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 225 .00f ,
271- 150 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 150 .00f ,
272- 150 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 150 .00f ,
273- 225 .00f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 225 .00f ,
274- 225 .00f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 225 .00f ,
275- 225 .00f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 225 .00f ,
276- 225 .00f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 225 .00f ,
277- 150 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 150 .00f ,
278- 150 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 150 .00f ,
279- 225 .00f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 225 .00f ,
280- 225 .00f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 225 .00f ,
281- 225 .00f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 225 .00f ,
282- 225 .00f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 225 .00f ,
283- 150 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 150 .00f ,
284- 150 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 150 .00f ,
285- 225 .00f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 225 .00f ,
286- 225 .00f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 225 .00f ,
287- 225 .00f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 225 .00f ,
288- 225 .00f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 225 .00f ,
289- 150 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 150 .00f ,
290- 150 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 150 .00f ,
291- 225 .00f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 225 .00f ,
292- 225 .00f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 225 .00f ,
293- 225 .00f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 225 .00f ,
294- 225 .00f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 225 .00f ,
295- 150 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 150 .00f ,
296- 150 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 150 .00f ,
297- 225 .00f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 225 .00f ,
298- 225 .00f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 225 .00f ,
299- 225 .00f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 225 .00f ,
300- 225 .00f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 337 .50f , 225 .00f ,
301- 150 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 225 .00f , 150 .00f };
302-
303- uint16_t expected_im2col[n_conv2d_test] = {
304- 0 , 0 , 0 , 0 , 15872 , 15872 , 0 , 15872 ,
305- 15872 , 0 , 0 , 0 , 0 , 15872 , 15872 , 0 ,
306- 15872 , 15872 , 0 , 0 , 0 , 0 , 15872 , 15872 ,
307- 0 , 15872 , 15872 , 0 , 0 , 0 , 0 , 15872 ,
308- 15872 , 0 , 15872 , 15872 , 0 , 0 , 0 , 0 ,
309- 15872 , 15872 , 0 , 15872 , 15872 , 0 , 0 , 0 ,
310- 0 , 15872 , 15872 , 0 , 15872 , 15872 , 0 , 0 ,
311- 0 , 0 , 15872 , 15872 , 0 , 15872 , 15872 , 0 ,
312- 0 , 0 , 0 , 15872 , 15872 , 0 , 15872 , 15872 ,
313- 0 , 0 , 0 , 0 , 15872 , 15872 , 0 , 15872 ,
314- 15872 , 0 , 0 , 0 , 0 , 15872 , 15872 , 0 ,
315- 15872 , 15872 , 0 , 0 , 0 , 15872 , 15872 , 15872 ,
316- 15872 , 15872 , 15872 , 0 , 0 , 0 , 15872 , 15872 ,
317- 15872 , 15872 , 15872 , 15872 , 0 , 0 , 0 , 15872 ,
318- 15872 , 15872 , 15872 , 15872 , 15872 , 0 , 0 , 0 ,
319- 15872 , 15872 , 15872 , 15872 , 15872 , 15872 , 0 , 0 ,
320- 0 , 15872 , 15872 , 15872 , 15872 , 15872 , 15872 , 0 ,
321- 0 , 0 , 15872 , 15872 , 15872 , 15872 , 15872 , 15872 ,
322- 0 , 0 , 0 , 15872 , 15872 , 15872 , 15872 , 15872 ,
323- 15872 , 0 , 0 , 0 , 15872 , 15872 , 15872 , 15872 ,
324- 15872 , 15872 , 0 , 0 , 0 , 15872 , 15872 , 15872 ,
325- 15872 , 15872 , 15872 , 0 , 0 , 0 , 15872 , 15872 ,
326- 15872 , 15872 , 15872 , 15872 , 0 , 0 , 0 , 15872 ,
327- 15872 , 15872 , 15872 , 15872 , 15872 , 0 , 0 , 0 ,
328- 15872 , 15872 , 15872 , 15872 , 15872 , 15872 , 0 , 0 ,
329- 0 , 15872 , 15872 , 15872 , 15872 , 15872 , 15872 , 0 ,
330- 0 , 0 , 15872 , 15872 , 15872 , 15872 , 15872 , 15872 ,
331- 0 , 0 , 0 , 15872 , 15872 , 15872 , 15872 , 15872 ,
332- 15872 , 0 , 0 , 0 , 15872 , 15872 , 15872 , 15872 ,
333- 15872 , 15872 , 0 , 0 , 0 , 15872 , 15872 , 15872 ,
334- 15872 , 15872 , 15872 , 0 , 0 , 0 , 15872 , 15872 ,
335- 15872 , 15872 , 15872 , 15872 , 0 , 0 , 0 , 15872 ,
336- 15872 , 15872 , 15872 , 15872 , 15872 , 0 , 0 , 0 ,
337- 15872 , 15872 , 15872 , 15872 , 15872 , 15872 , 0 , 0 ,
338- 0 , 15872 , 15872 , 15872 , 15872 , 15872 , 15872 , 0 ,
339- 0 , 0 , 15872 , 15872 , 15872 , 15872 , 15872 , 15872 ,
340- 0 , 0 , 0 , 15872 , 15872 , 15872 , 15872 , 15872 ,
341- 15872 , 0 , 0 , 0 , 15872 , 15872 , 15872 , 15872 ,
342- 15872 , 15872 , 0 , 0 , 0 , 15872 , 15872 , 15872 ,
343- 15872 , 15872 , 15872 , 0 , 0 , 0 , 15872 , 15872 ,
344- 15872 , 15872 , 15872 , 15872 , 0 , 0 , 0 , 15872 ,
345- 15872 , 15872 , 15872 , 15872 , 15872 , 0 , 0 , 0 ,
346- 15872 , 15872 , 15872 , 15872 , 15872 , 15872 , 0 , 0 ,
347- 0 , 15872 , 15872 , 15872 , 15872 , 15872 , 15872 , 0 ,
348- 0 , 0 , 15872 , 15872 , 15872 , 15872 , 15872 , 15872 ,
349- 0 , 0 , 0 , 15872 , 15872 , 15872 , 15872 , 15872 ,
350- 15872 , 0 , 0 , 0 , 15872 , 15872 , 15872 , 15872 ,
351- 15872 , 15872 , 0 , 0 , 0 , 15872 , 15872 , 15872 ,
352- 15872 , 15872 , 15872 , 0 , 0 , 0 , 15872 , 15872 ,
353- 15872 , 15872 , 15872 , 15872 , 0 , 0 , 0 , 15872 ,
354- 15872 , 15872 , 15872 , 15872 , 15872 , 0 , 0 , 0 ,
355- 15872 , 15872 , 15872 , 15872 , 15872 , 15872 , 0 , 0 ,
356- 0 , 15872 , 15872 , 15872 , 15872 , 15872 , 15872 , 0 ,
357- 0 , 0 , 15872 , 15872 , 15872 , 15872 , 15872 , 15872 ,
358- 0 , 0 , 0 , 15872 , 15872 , 15872 , 15872 , 15872 ,
359- 15872 , 0 , 0 , 0 , 15872 , 15872 , 15872 , 15872 ,
360- 15872 , 15872 , 0 , 0 , 0 , 15872 , 15872 , 15872 ,
361- 15872 , 15872 , 15872 , 0 , 0 , 0 , 15872 , 15872 ,
362- 15872 , 15872 , 15872 , 15872 , 0 , 0 , 0 , 15872 ,
363- 15872 , 15872 , 15872 , 15872 , 15872 , 0 , 0 , 0
364- };
365-
366- printf (" \n Performing test:\n " );
243+
244+ printf (" \n Performing test:\n " );
367245
368246 bool passed = true ;
369- for (int i = 0 ; i < n_conv2d_test; i++) {
370- if (
371- im2col_data[i] != expected_im2col[i]) {
372- passed = false ;
373- break ;
374- }
375- }
376-
377- printf (" ggml_im2col (%d): %s\n " , (int ) ggml_nelements (im2col_res), passed && (ggml_nelements (im2col_res) == n_im2col_test) ? " \033 [32mPASSED\033 [0m" : " \033 [31mFAILED\033 [0m" );
378-
379- passed = true ;
380- for (int i = 0 ; i < n_conv2d_test; i++) {
381- if (conv2d_data[i] != expected_conv2d[i]) {
382- passed = false ;
383- break ;
384- }
247+ // for(int i = 0; i < ggml_nelements(wino_res); i++) {
248+ for (int i = 0 ; i < 3 *28 ; i++) {
249+ float diff = fabs (conv2d_data[i] - wino_data[i]);
250+ // if(diff > 1.e-4) {
251+ printf (" (%f, %f, %f, %d) \n " ,
252+ conv2d_data[i],
253+ wino_data[i], diff, i);
254+ // break;
255+ // }
385256 }
386257
387- printf ( " ggml_conv2d (%d): %s \n " , ( int ) ggml_nelements (conv2d_res), passed && ( ggml_nelements (conv2d_res) == n_conv2d_test) ? " \033 [32mPASSED \033 [0m " : " \033 [31mFAILED \033 [0m " );
258+
388259
389260 ggml_free (model.ctx );
390261
0 commit comments