3535#include  < random> 
3636#include  < regex> 
3737#include  < string> 
38+ #include  < string_view> 
3839#include  < thread> 
3940#include  < vector> 
4041
@@ -1020,7 +1021,37 @@ struct test_case {
10201021        return  t;
10211022    }
10221023
1023-     bool  eval (ggml_backend_t  backend1, ggml_backend_t  backend2, const  char  * op_name, printer * output_printer) {
1024+     //  Checks an op against the test filter, which is a comma separated list of OP names or specific variations
1025+     bool  matches_filter (ggml_tensor* op, const  char * op_names_filter) {
1026+         if  (op_names_filter) {
1027+             const  auto  op_name = op_desc (op);
1028+             const  auto  op_full_name = op_name + " ("   + vars () + " )"  ;
1029+             std::string_view filter (op_names_filter);
1030+             while  (!filter.empty ()) {
1031+                 auto  comma_pos = filter.find_first_of (' ,'  );
1032+                 const  auto  lparen_pos = filter.find_first_of (' ('  );
1033+                 if  (lparen_pos < comma_pos) {
1034+                     auto  rparen_pos = filter.find_first_of (' )'  );
1035+                     comma_pos = filter.find_first_of (' ,'  , rparen_pos);
1036+                     const  auto  op_filter = filter.substr (0 , comma_pos);
1037+                     if  (op_filter == op_full_name) {
1038+                         return  true ;
1039+                     }
1040+                 } else  {
1041+                     const  auto  op_filter = filter.substr (0 , comma_pos);
1042+                     if  (op_filter == op_name) {
1043+                         return  true ;
1044+                     }
1045+                 }
1046+                 filter = comma_pos != std::string_view::npos ? filter.substr (comma_pos + 1 ) : " "  ;
1047+             }
1048+             return  false ;
1049+         } else  {
1050+             return  true ;
1051+         }
1052+     }
1053+ 
1054+     bool  eval (ggml_backend_t  backend1, ggml_backend_t  backend2, const  char  * op_names_filter, printer * output_printer) {
10241055        mode = MODE_TEST;
10251056
10261057        ggml_init_params params = {
@@ -1038,7 +1069,7 @@ struct test_case {
10381069
10391070        ggml_tensor * out = build_graph (ctx);
10401071        std::string current_op_name = op_desc (out);
1041-         if  (op_name !=  nullptr  && current_op_name != op_name ) {
1072+         if  (! matches_filter (out, op_names_filter) ) {
10421073            // printf("  %s: skipping\n", op_desc(out).c_str());
10431074            ggml_free (ctx);
10441075            return  true ;
@@ -1185,7 +1216,7 @@ struct test_case {
11851216        return  test_passed;
11861217    }
11871218
1188-     bool  eval_perf (ggml_backend_t  backend, const  char  * op_name , printer * output_printer) {
1219+     bool  eval_perf (ggml_backend_t  backend, const  char  * op_names_filter , printer * output_printer) {
11891220        mode = MODE_PERF;
11901221
11911222        static  const  size_t  graph_nodes = 8192 ;
@@ -1200,7 +1231,7 @@ struct test_case {
12001231
12011232        ggml_tensor * out             = build_graph (ctx.get ());
12021233        std::string   current_op_name = op_desc (out);
1203-         if  (op_name !=  nullptr  && current_op_name != op_name ) {
1234+         if  (! matches_filter (out, op_names_filter) ) {
12041235            // printf("  %s: skipping\n", op_desc(out).c_str());
12051236            return  true ;
12061237        }
@@ -1315,7 +1346,7 @@ struct test_case {
13151346        return  true ;
13161347    }
13171348
1318-     bool  eval_support (ggml_backend_t  backend, const  char  * op_name , printer * output_printer) {
1349+     bool  eval_support (ggml_backend_t  backend, const  char  * op_names_filter , printer * output_printer) {
13191350        mode = MODE_SUPPORT;
13201351
13211352        static  const  size_t  graph_nodes = 8192 ;
@@ -1330,7 +1361,7 @@ struct test_case {
13301361
13311362        ggml_tensor * out             = build_graph (ctx.get ());
13321363        std::string   current_op_name = op_desc (out);
1333-         if  (op_name !=  nullptr  && current_op_name != op_name ) {
1364+         if  (! matches_filter (out, op_names_filter) ) {
13341365            return  true ;
13351366        }
13361367
@@ -1347,7 +1378,7 @@ struct test_case {
13471378        return  true ;
13481379    }
13491380
1350-     bool  eval_grad (ggml_backend_t  backend, const  char  * op_name , printer * output_printer) {
1381+     bool  eval_grad (ggml_backend_t  backend, const  char  * op_names_filter , printer * output_printer) {
13511382        mode = MODE_GRAD;
13521383        const  std::vector<float > expect = grad_expect ();
13531384
@@ -1364,7 +1395,7 @@ struct test_case {
13641395
13651396        ggml_tensor * out = build_graph (ctx.get ());
13661397
1367-         if  ((op_name !=  nullptr  &&  op_desc (out) != op_name ) || out->op  == GGML_OP_OPT_STEP_ADAMW) {
1398+         if  (! matches_filter (out, op_names_filter ) || out->op  == GGML_OP_OPT_STEP_ADAMW) {
13681399            return  true ;
13691400        }
13701401
@@ -5881,7 +5912,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
58815912    return  test_cases;
58825913}
58835914
5884- static  bool  test_backend (ggml_backend_t  backend, test_mode mode, const  char  * op_name , const  char  * params_filter,
5915+ static  bool  test_backend (ggml_backend_t  backend, test_mode mode, const  char  * op_names_filter , const  char  * params_filter,
58855916                         printer * output_printer) {
58865917    auto  filter_test_cases = [](std::vector<std::unique_ptr<test_case>> & test_cases, const  char  * params_filter) {
58875918        if  (params_filter == nullptr ) {
@@ -5913,7 +5944,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
59135944
59145945        size_t  n_ok = 0 ;
59155946        for  (auto  & test : test_cases) {
5916-             if  (test->eval (backend, backend_cpu, op_name , output_printer)) {
5947+             if  (test->eval (backend, backend_cpu, op_names_filter , output_printer)) {
59175948                n_ok++;
59185949            }
59195950        }
@@ -5929,7 +5960,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
59295960        filter_test_cases (test_cases, params_filter);
59305961        size_t  n_ok = 0 ;
59315962        for  (auto  & test : test_cases) {
5932-             if  (test->eval_grad (backend, op_name , output_printer)) {
5963+             if  (test->eval_grad (backend, op_names_filter , output_printer)) {
59335964                n_ok++;
59345965            }
59355966        }
@@ -5942,7 +5973,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
59425973        auto  test_cases = make_test_cases_perf ();
59435974        filter_test_cases (test_cases, params_filter);
59445975        for  (auto  & test : test_cases) {
5945-             test->eval_perf (backend, op_name , output_printer);
5976+             test->eval_perf (backend, op_names_filter , output_printer);
59465977        }
59475978        return  true ;
59485979    }
@@ -5951,7 +5982,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
59515982        auto  test_cases = make_test_cases_eval ();
59525983        filter_test_cases (test_cases, params_filter);
59535984        for  (auto  & test : test_cases) {
5954-             test->eval_support (backend, op_name , output_printer);
5985+             test->eval_support (backend, op_names_filter , output_printer);
59555986        }
59565987        return  true ;
59575988    }
@@ -5973,7 +6004,7 @@ static void usage(char ** argv) {
59736004int  main (int  argc, char  ** argv) {
59746005    test_mode mode = MODE_TEST;
59756006    output_formats output_format = CONSOLE;
5976-     const  char  * op_name_filter  = nullptr ;
6007+     const  char  * op_names_filter  = nullptr ;
59776008    const  char  * backend_filter = nullptr ;
59786009    const  char  * params_filter = nullptr ;
59796010
@@ -5988,7 +6019,7 @@ int main(int argc, char ** argv) {
59886019            mode = MODE_SUPPORT;
59896020        } else  if  (strcmp (argv[i], " -o"  ) == 0 ) {
59906021            if  (i + 1  < argc) {
5991-                 op_name_filter  = argv[++i];
6022+                 op_names_filter  = argv[++i];
59926023            } else  {
59936024                usage (argv);
59946025                return  1 ;
@@ -6069,7 +6100,7 @@ int main(int argc, char ** argv) {
60696100                                                             false , " "  , ggml_backend_dev_description (dev),
60706101                                                             total / 1024  / 1024 , free / 1024  / 1024 , true ));
60716102
6072-         bool  ok = test_backend (backend, mode, op_name_filter , params_filter, output_printer.get ());
6103+         bool  ok = test_backend (backend, mode, op_names_filter , params_filter, output_printer.get ());
60736104
60746105        if  (ok) {
60756106            n_ok++;
0 commit comments