2323#include  < algorithm> 
2424#include  < array> 
2525#include  < cfloat> 
26+ #include  < cinttypes> 
2627#include  < cstdint> 
28+ #include  < cstdio> 
29+ #include  < cstdlib> 
2730#include  < cstring> 
28- #include  < cinttypes > 
31+ #include  < future > 
2932#include  < memory> 
3033#include  < random> 
31- #include  < stdio.h> 
32- #include  < stdlib.h> 
34+ #include  < regex> 
3335#include  < string> 
3436#include  < thread> 
35- #include  < future> 
3637#include  < vector> 
3738
3839static  void  init_tensor_uniform (ggml_tensor * tensor, float  min = -1 .0f , float  max = 1 .0f ) {
@@ -4382,9 +4383,27 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
43824383    return  test_cases;
43834384}
43844385
4385- static  bool  test_backend (ggml_backend_t  backend, test_mode mode, const  char  * op_name) {
4386+ static  bool  test_backend (ggml_backend_t  backend, test_mode mode, const  char  * op_name, const  char  * params_filter) {
4387+     auto  filter_test_cases = [](std::vector<std::unique_ptr<test_case>> & test_cases, const  char  * params_filter) {
4388+         if  (params_filter == nullptr ) {
4389+             return ;
4390+         }
4391+ 
4392+         std::regex params_filter_regex (params_filter);
4393+ 
4394+         for  (auto  it = test_cases.begin (); it != test_cases.end ();) {
4395+             if  (!std::regex_search ((*it)->vars (), params_filter_regex)) {
4396+                 it = test_cases.erase (it);
4397+                 continue ;
4398+             }
4399+ 
4400+             it++;
4401+         }
4402+     };
4403+ 
43864404    if  (mode == MODE_TEST) {
43874405        auto  test_cases = make_test_cases_eval ();
4406+         filter_test_cases (test_cases, params_filter);
43884407        ggml_backend_t  backend_cpu = ggml_backend_init_by_type (GGML_BACKEND_DEVICE_TYPE_CPU, NULL );
43894408        if  (backend_cpu == NULL ) {
43904409            printf ("   Failed to initialize CPU backend\n "  );
@@ -4406,6 +4425,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
44064425
44074426    if  (mode == MODE_GRAD) {
44084427        auto  test_cases = make_test_cases_eval ();
4428+         filter_test_cases (test_cases, params_filter);
44094429        size_t  n_ok = 0 ;
44104430        for  (auto  & test : test_cases) {
44114431            if  (test->eval_grad (backend, op_name)) {
@@ -4419,6 +4439,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
44194439
44204440    if  (mode == MODE_PERF) {
44214441        auto  test_cases = make_test_cases_perf ();
4442+         filter_test_cases (test_cases, params_filter);
44224443        for  (auto  & test : test_cases) {
44234444            test->eval_perf (backend, op_name);
44244445        }
@@ -4429,7 +4450,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
44294450}
44304451
44314452static  void  usage (char  ** argv) {
4432-     printf (" Usage: %s [mode] [-o op ] [-b backend]\n "  , argv[0 ]);
4453+     printf (" Usage: %s [mode] [-o <op> ] [-b < backend>] [-p <params regex> ]\n "  , argv[0 ]);
44334454    printf ("     valid modes:\n "  );
44344455    printf ("       - test (default, compare with CPU backend for correctness)\n "  );
44354456    printf ("       - grad (compare gradients from backpropagation with method of finite differences)\n "  );
@@ -4439,8 +4460,9 @@ static void usage(char ** argv) {
44394460
44404461int  main (int  argc, char  ** argv) {
44414462    test_mode mode = MODE_TEST;
4442-     const  char  * op_name_filter = NULL ;
4443-     const  char  * backend_filter = NULL ;
4463+     const  char  * op_name_filter = nullptr ;
4464+     const  char  * backend_filter = nullptr ;
4465+     const  char  * params_filter = nullptr ;
44444466
44454467    for  (int  i = 1 ; i < argc; i++) {
44464468        if  (strcmp (argv[i], " test"  ) == 0 ) {
@@ -4463,6 +4485,13 @@ int main(int argc, char ** argv) {
44634485                usage (argv);
44644486                return  1 ;
44654487            }
4488+         } else  if  (strcmp (argv[i], " -p"  ) == 0 ) {
4489+             if  (i + 1  < argc) {
4490+                 params_filter = argv[++i];
4491+             } else  {
4492+                 usage (argv);
4493+                 return  1 ;
4494+             }
44664495        } else  {
44674496            usage (argv);
44684497            return  1 ;
@@ -4509,7 +4538,7 @@ int main(int argc, char ** argv) {
45094538        printf ("   Device memory: %zu MB (%zu MB free)\n "  , total / 1024  / 1024 , free / 1024  / 1024 );
45104539        printf (" \n "  );
45114540
4512-         bool  ok = test_backend (backend, mode, op_name_filter);
4541+         bool  ok = test_backend (backend, mode, op_name_filter, params_filter );
45134542
45144543        printf ("   Backend %s: "  , ggml_backend_name (backend));
45154544        if  (ok) {
0 commit comments