2323#include < algorithm>
2424#include < array>
2525#include < cfloat>
26+ #include < cinttypes>
2627#include < cstdint>
28+ #include < cstdio>
29+ #include < cstdlib>
2730#include < cstring>
28- #include < cinttypes >
31+ #include < future >
2932#include < memory>
3033#include < random>
31- #include < stdio.h>
32- #include < stdlib.h>
34+ #include < regex>
3335#include < string>
3436#include < thread>
35- #include < future>
3637#include < vector>
3738
3839static void init_tensor_uniform (ggml_tensor * tensor, float min = -1 .0f , float max = 1 .0f ) {
@@ -4382,9 +4383,27 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
43824383 return test_cases;
43834384}
43844385
4385- static bool test_backend (ggml_backend_t backend, test_mode mode, const char * op_name) {
4386+ static bool test_backend (ggml_backend_t backend, test_mode mode, const char * op_name, const char * params_filter) {
4387+ auto filter_test_cases = [](std::vector<std::unique_ptr<test_case>> & test_cases, const char * params_filter) {
4388+ if (params_filter == nullptr ) {
4389+ return ;
4390+ }
4391+
4392+ std::regex params_filter_regex (params_filter);
4393+
4394+ for (auto it = test_cases.begin (); it != test_cases.end ();) {
4395+ if (!std::regex_search ((*it)->vars (), params_filter_regex)) {
4396+ it = test_cases.erase (it);
4397+ continue ;
4398+ }
4399+
4400+ it++;
4401+ }
4402+ };
4403+
43864404 if (mode == MODE_TEST) {
43874405 auto test_cases = make_test_cases_eval ();
4406+ filter_test_cases (test_cases, params_filter);
43884407 ggml_backend_t backend_cpu = ggml_backend_init_by_type (GGML_BACKEND_DEVICE_TYPE_CPU, NULL );
43894408 if (backend_cpu == NULL ) {
43904409 printf (" Failed to initialize CPU backend\n " );
@@ -4406,6 +4425,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
44064425
44074426 if (mode == MODE_GRAD) {
44084427 auto test_cases = make_test_cases_eval ();
4428+ filter_test_cases (test_cases, params_filter);
44094429 size_t n_ok = 0 ;
44104430 for (auto & test : test_cases) {
44114431 if (test->eval_grad (backend, op_name)) {
@@ -4419,6 +4439,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
44194439
44204440 if (mode == MODE_PERF) {
44214441 auto test_cases = make_test_cases_perf ();
4442+ filter_test_cases (test_cases, params_filter);
44224443 for (auto & test : test_cases) {
44234444 test->eval_perf (backend, op_name);
44244445 }
@@ -4429,7 +4450,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
44294450}
44304451
44314452static void usage (char ** argv) {
4432- printf (" Usage: %s [mode] [-o op ] [-b backend]\n " , argv[0 ]);
4453+ printf (" Usage: %s [mode] [-o <op> ] [-b < backend>] [-p <params regex> ]\n " , argv[0 ]);
44334454 printf (" valid modes:\n " );
44344455 printf (" - test (default, compare with CPU backend for correctness)\n " );
44354456 printf (" - grad (compare gradients from backpropagation with method of finite differences)\n " );
@@ -4439,8 +4460,9 @@ static void usage(char ** argv) {
44394460
44404461int main (int argc, char ** argv) {
44414462 test_mode mode = MODE_TEST;
4442- const char * op_name_filter = NULL ;
4443- const char * backend_filter = NULL ;
4463+ const char * op_name_filter = nullptr ;
4464+ const char * backend_filter = nullptr ;
4465+ const char * params_filter = nullptr ;
44444466
44454467 for (int i = 1 ; i < argc; i++) {
44464468 if (strcmp (argv[i], " test" ) == 0 ) {
@@ -4463,6 +4485,13 @@ int main(int argc, char ** argv) {
44634485 usage (argv);
44644486 return 1 ;
44654487 }
4488+ } else if (strcmp (argv[i], " -p" ) == 0 ) {
4489+ if (i + 1 < argc) {
4490+ params_filter = argv[++i];
4491+ } else {
4492+ usage (argv);
4493+ return 1 ;
4494+ }
44664495 } else {
44674496 usage (argv);
44684497 return 1 ;
@@ -4509,7 +4538,7 @@ int main(int argc, char ** argv) {
45094538 printf (" Device memory: %zu MB (%zu MB free)\n " , total / 1024 / 1024 , free / 1024 / 1024 );
45104539 printf (" \n " );
45114540
4512- bool ok = test_backend (backend, mode, op_name_filter);
4541+ bool ok = test_backend (backend, mode, op_name_filter, params_filter );
45134542
45144543 printf (" Backend %s: " , ggml_backend_name (backend));
45154544 if (ok) {
0 commit comments