3535#include < random>
3636#include < regex>
3737#include < string>
38+ #include < string_view>
3839#include < thread>
3940#include < vector>
4041
@@ -1047,7 +1048,37 @@ struct test_case {
10471048 return t;
10481049 }
10491050
1050- bool eval (ggml_backend_t backend1, ggml_backend_t backend2, const char * op_name, printer * output_printer) {
1051+ // Checks an op against the test filter, which is a comma separated list of OP names or specific variations
1052+ bool matches_filter (ggml_tensor * op, const char * op_names_filter) {
1053+ if (op_names_filter) {
1054+ const auto op_name = op_desc (op);
1055+ const auto op_full_name = op_name + " (" + vars () + " )" ;
1056+ std::string_view filter (op_names_filter);
1057+ while (!filter.empty ()) {
1058+ auto comma_pos = filter.find_first_of (' ,' );
1059+ const auto lparen_pos = filter.find_first_of (' (' );
1060+ if (lparen_pos < comma_pos) {
1061+ auto rparen_pos = filter.find_first_of (' )' );
1062+ comma_pos = filter.find_first_of (' ,' , rparen_pos);
1063+ const auto op_filter = filter.substr (0 , comma_pos);
1064+ if (op_filter == op_full_name) {
1065+ return true ;
1066+ }
1067+ } else {
1068+ const auto op_filter = filter.substr (0 , comma_pos);
1069+ if (op_filter == op_name) {
1070+ return true ;
1071+ }
1072+ }
1073+ filter = comma_pos != std::string_view::npos ? filter.substr (comma_pos + 1 ) : " " ;
1074+ }
1075+ return false ;
1076+ } else {
1077+ return true ;
1078+ }
1079+ }
1080+
1081+ bool eval (ggml_backend_t backend1, ggml_backend_t backend2, const char * op_names_filter, printer * output_printer) {
10511082 mode = MODE_TEST;
10521083
10531084 ggml_init_params params = {
@@ -1065,7 +1096,7 @@ struct test_case {
10651096
10661097 ggml_tensor * out = build_graph (ctx);
10671098 std::string current_op_name = op_desc (out);
1068- if (op_name != nullptr && current_op_name != op_name ) {
1099+ if (! matches_filter (out, op_names_filter) ) {
10691100 // printf(" %s: skipping\n", op_desc(out).c_str());
10701101 ggml_free (ctx);
10711102 return true ;
@@ -1212,7 +1243,7 @@ struct test_case {
12121243 return test_passed;
12131244 }
12141245
1215- bool eval_perf (ggml_backend_t backend, const char * op_name , printer * output_printer) {
1246+ bool eval_perf (ggml_backend_t backend, const char * op_names_filter , printer * output_printer) {
12161247 mode = MODE_PERF;
12171248
12181249 static const size_t graph_nodes = 8192 ;
@@ -1227,7 +1258,7 @@ struct test_case {
12271258
12281259 ggml_tensor * out = build_graph (ctx.get ());
12291260 std::string current_op_name = op_desc (out);
1230- if (op_name != nullptr && current_op_name != op_name ) {
1261+ if (! matches_filter (out, op_names_filter) ) {
12311262 // printf(" %s: skipping\n", op_desc(out).c_str());
12321263 return true ;
12331264 }
@@ -1342,7 +1373,7 @@ struct test_case {
13421373 return true ;
13431374 }
13441375
1345- bool eval_support (ggml_backend_t backend, const char * op_name , printer * output_printer) {
1376+ bool eval_support (ggml_backend_t backend, const char * op_names_filter , printer * output_printer) {
13461377 mode = MODE_SUPPORT;
13471378
13481379 static const size_t graph_nodes = 8192 ;
@@ -1357,7 +1388,7 @@ struct test_case {
13571388
13581389 ggml_tensor * out = build_graph (ctx.get ());
13591390 std::string current_op_name = op_desc (out);
1360- if (op_name != nullptr && current_op_name != op_name ) {
1391+ if (! matches_filter (out, op_names_filter) ) {
13611392 return true ;
13621393 }
13631394
@@ -1374,7 +1405,7 @@ struct test_case {
13741405 return true ;
13751406 }
13761407
1377- bool eval_grad (ggml_backend_t backend, const char * op_name , printer * output_printer) {
1408+ bool eval_grad (ggml_backend_t backend, const char * op_names_filter , printer * output_printer) {
13781409 mode = MODE_GRAD;
13791410 const std::vector<float > expect = grad_expect ();
13801411
@@ -1391,7 +1422,7 @@ struct test_case {
13911422
13921423 ggml_tensor * out = build_graph (ctx.get ());
13931424
1394- if ((op_name != nullptr && op_desc (out) != op_name ) || out->op == GGML_OP_OPT_STEP_ADAMW) {
1425+ if (! matches_filter (out, op_names_filter ) || out->op == GGML_OP_OPT_STEP_ADAMW) {
13951426 return true ;
13961427 }
13971428
@@ -5922,7 +5953,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
59225953 return test_cases;
59235954}
59245955
5925- static bool test_backend (ggml_backend_t backend, test_mode mode, const char * op_name , const char * params_filter,
5956+ static bool test_backend (ggml_backend_t backend, test_mode mode, const char * op_names_filter , const char * params_filter,
59265957 printer * output_printer) {
59275958 auto filter_test_cases = [](std::vector<std::unique_ptr<test_case>> & test_cases, const char * params_filter) {
59285959 if (params_filter == nullptr ) {
@@ -5954,7 +5985,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
59545985
59555986 size_t n_ok = 0 ;
59565987 for (auto & test : test_cases) {
5957- if (test->eval (backend, backend_cpu, op_name , output_printer)) {
5988+ if (test->eval (backend, backend_cpu, op_names_filter , output_printer)) {
59585989 n_ok++;
59595990 }
59605991 }
@@ -5970,7 +6001,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
59706001 filter_test_cases (test_cases, params_filter);
59716002 size_t n_ok = 0 ;
59726003 for (auto & test : test_cases) {
5973- if (test->eval_grad (backend, op_name , output_printer)) {
6004+ if (test->eval_grad (backend, op_names_filter , output_printer)) {
59746005 n_ok++;
59756006 }
59766007 }
@@ -5983,7 +6014,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
59836014 auto test_cases = make_test_cases_perf ();
59846015 filter_test_cases (test_cases, params_filter);
59856016 for (auto & test : test_cases) {
5986- test->eval_perf (backend, op_name , output_printer);
6017+ test->eval_perf (backend, op_names_filter , output_printer);
59876018 }
59886019 return true ;
59896020 }
@@ -5992,7 +6023,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
59926023 auto test_cases = make_test_cases_eval ();
59936024 filter_test_cases (test_cases, params_filter);
59946025 for (auto & test : test_cases) {
5995- test->eval_support (backend, op_name , output_printer);
6026+ test->eval_support (backend, op_names_filter , output_printer);
59966027 }
59976028 return true ;
59986029 }
@@ -6001,20 +6032,21 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
60016032}
60026033
60036034static void usage (char ** argv) {
6004- printf (" Usage: %s [mode] [-o <op>] [-b <backend>] [-p <params regex>] [--output <console|sql|csv>]\n " , argv[0 ]);
6035+ printf (" Usage: %s [mode] [-o <op,.. >] [-b <backend>] [-p <params regex>] [--output <console|sql|csv>]\n " , argv[0 ]);
60056036 printf (" valid modes:\n " );
60066037 printf (" - test (default, compare with CPU backend for correctness)\n " );
60076038 printf (" - grad (compare gradients from backpropagation with method of finite differences)\n " );
60086039 printf (" - perf (performance evaluation)\n " );
60096040 printf (" - support (probe backend operation support)\n " );
6010- printf (" op names for -o are as given by ggml_op_desc() (e.g. ADD, MUL_MAT, etc)\n " );
6041+ printf (" op names for -o are as given by ggml_op_desc() (e.g. ADD, MUL_MAT, etc),\n " );
6042+ printf (" optionally including the full test case string (e.g. \" ADD(type=f16,ne=[1,1,8,1],nr=[1,1,1,1],nf=1)\" )\n " );
60116043 printf (" --output specifies output format (default: console, options: console, sql, csv)\n " );
60126044}
60136045
60146046int main (int argc, char ** argv) {
60156047 test_mode mode = MODE_TEST;
60166048 output_formats output_format = CONSOLE;
6017- const char * op_name_filter = nullptr ;
6049+ const char * op_names_filter = nullptr ;
60186050 const char * backend_filter = nullptr ;
60196051 const char * params_filter = nullptr ;
60206052
@@ -6029,7 +6061,7 @@ int main(int argc, char ** argv) {
60296061 mode = MODE_SUPPORT;
60306062 } else if (strcmp (argv[i], " -o" ) == 0 ) {
60316063 if (i + 1 < argc) {
6032- op_name_filter = argv[++i];
6064+ op_names_filter = argv[++i];
60336065 } else {
60346066 usage (argv);
60356067 return 1 ;
@@ -6110,7 +6142,7 @@ int main(int argc, char ** argv) {
61106142 false , " " , ggml_backend_dev_description (dev),
61116143 total / 1024 / 1024 , free / 1024 / 1024 , true ));
61126144
6113- bool ok = test_backend (backend, mode, op_name_filter , params_filter, output_printer.get ());
6145+ bool ok = test_backend (backend, mode, op_names_filter , params_filter, output_printer.get ());
61146146
61156147 if (ok) {
61166148 n_ok++;
0 commit comments