@@ -36,7 +36,9 @@ struct test_model {
     struct ggml_context * ctx;
 };
 
-
+void load_model(test_model &, int, int, int, int, int, int, bool);
+struct ggml_cgraph * build_graph_0(const test_model&);
+struct ggml_cgraph * build_graph_1(const test_model&);
 
 void load_model(test_model & model, int ic, int oc, int iw, int ih, int kw = 3, int kh = 3, bool use_gpu = false) {
     // create data
@@ -102,7 +104,6 @@ void load_model(test_model & model, int ic, int oc, int iw, int ih, int kw = 3,
 #ifdef GGML_USE_METAL
     if (use_gpu) {
         fprintf(stderr, "%s: using Metal backend\n", __func__);
-        ggml_backend_metal_log_set_callback(ggml_log_callback_default, nullptr);
         model.backend = ggml_backend_metal_init();
         if (!model.backend) {
             fprintf(stderr, "%s: ggml_backend_metal_init() failed\n", __func__);
@@ -178,8 +179,6 @@ struct ggml_cgraph * build_graph_0(const test_model& model) {
     int d0 = 1;
     int d1 = 1;
 
-
-
     // recalculate for avoid fragmentation
     struct ggml_tensor * conv2d_res = ggml_conv_2d(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1);
     ggml_set_name(conv2d_res, "conv2d_res");
@@ -219,8 +218,6 @@ struct ggml_cgraph * build_graph_1(const test_model& model) {
     int d0 = 1;
     int d1 = 1;
 
-
-
     // recalculate for avoid fragmentation
     // struct ggml_tensor* conv2d_res = ggml_conv_2d(ctx0, model.a, model.b, s0, s1, p0, p1, d0, d1);
     // ggml_set_name(conv2d_res, "conv2d_res");
@@ -239,7 +236,8 @@ struct ggml_cgraph * build_graph_1(const test_model& model) {
     return gf;
 }
 
-
+std::vector<float> compute_graph(const test_model &, ggml_gallocr_t,
+                                 build_graph_t, int, double *);
 
 
 std::vector<float> compute_graph(const test_model & model, ggml_gallocr_t allocr,
@@ -255,14 +253,6 @@ std::vector<float> compute_graph(const test_model & model, ggml_gallocr_t allocr
         ggml_backend_cpu_set_n_threads(model.backend, n_threads);
     }
 
-#ifdef GGML_USE_METAL
-    if (ggml_backend_is_metal(model.backend)) {
-        ggml_backend_metal_set_n_cb(model.backend, n_threads);
-    }
-#endif
-
-
-
     ggml_backend_graph_compute(model.backend, gf);
 
     ggml_backend_synchronize(model.backend);
@@ -274,13 +264,11 @@ std::vector<float> compute_graph(const test_model & model, ggml_gallocr_t allocr
         ggml_backend_synchronize(model.backend);
     }
 
-    // ggml_backend_synchronize(model.backend);
     int64_t end_time = ggml_time_us();
     double time_us = end_time - start_time;
 
     time_us = time_us/iters;
-    // printf(" Taking %f ms\n", time_us/1000);
-
+
     // ggml_graph_print(gf);
 
     struct ggml_tensor *res = NULL;
@@ -334,7 +322,7 @@ int main(void)
 
     for (auto c : configs){
         test_model model;
-        load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c),
+        load_model(model, std::get<0>(c), std::get<1>(c), std::get<2>(c),
                    std::get<3>(c), std::get<4>(c), std::get<5>(c), true);
 
         ggml_gallocr_t allocr = NULL;
@@ -349,7 +337,6 @@ int main(void)
         // fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f);
 
 
-        struct ggml_cgraph * gf_res_0 = NULL;
         int iterations = 20;
 
         double run_time0;
@@ -368,15 +355,14 @@ int main(void)
         ggml_gallocr_reserve(allocr, gf);
         size_t mem_size1 = ggml_gallocr_get_buffer_size(allocr, 0);
         // fprintf(stderr, "%s: compute buffer size: %.2f MB\n", __func__, mem_size/1024.0f/1024.0f);
-
 
-        struct ggml_cgraph * gf_res_1 = NULL;
+
 
         double run_time1;
         // std::vector<float> wino_data = compute_graph(model, allocr, build_graph_1, iterations, &run_time1);
         std::vector<float> conv2d_data = compute_graph(model, allocr, build_graph_1, iterations, &run_time1);
 
-        if (k==0) {
+        if (k==0) {
             k = 1;
             fprintf(stderr, " | (IC, OC, IW, IH, KW, KH) | im2col+GEMM TIME | im2col+GEMM VRAM | implicit GEMM TIME | implicit GEMM VRAM \n");
             fprintf(stderr, " | --- | --- | --- | --- | --- \n");
@@ -409,6 +395,6 @@ int main(void)
 
     }
 
-    // printf("\nPerforming test:\n");
+    // printf("\nPerforming test:\n");
     return 0;
 }
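
Note: the new compute_graph prototype above takes a build_graph_t parameter whose definition falls outside these hunks. A minimal sketch of what that alias and its use presumably look like, assuming build_graph_t is a plain function pointer matching the build_graph_0/build_graph_1 signature (the exact definition is not shown in this diff):

    // Assumed alias (not part of this diff): a pointer to a graph-builder function
    // with the same signature as build_graph_0 and build_graph_1.
    typedef struct ggml_cgraph * (*build_graph_t)(const test_model &);

    // compute_graph can then drive either builder through the same benchmark path:
    //   std::vector<float> im2col_data   = compute_graph(model, allocr, build_graph_0, iterations, &run_time0);
    //   std::vector<float> implicit_data = compute_graph(model, allocr, build_graph_1, iterations, &run_time1);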