Skip to content

Commit a47f6ce

Browse files
committed
Extend test case filtering
1. Allow passing multiple (comma-separated?) ops to test-backend-ops. This can be convenient when working on a set of ops, when you'd want to test them together (but without having to run every single op). For example: `test-backend-ops.exe test -o "ADD,RMS_NORM,ROPE,SILU,SOFT_MAX"` 2. Support full test-case variation string in addition to basic op names. This would make it easy to select a single variation, either for testing or for benchmarking. It can be particularly useful for profiling a particular variation (ex. a CUDA kernel), for example: `test-backend-ops.exe perf -b CUDA0 -o "MUL_MAT(type_a=f16,type_b=f32,m=4096,n=512,k=14336,bs=[1,1],nr=[1,1],per=[0,1,2,3],v=2)"` These two can be combined. As the current `-o`, this change doesn't try to detect/report an error if an filter doesn't name existing ops (ex. misspelled)
1 parent 3f4fc97 commit a47f6ce

File tree

1 file changed

+47
-16
lines changed

1 file changed

+47
-16
lines changed

tests/test-backend-ops.cpp

Lines changed: 47 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
#include <random>
3636
#include <regex>
3737
#include <string>
38+
#include <string_view>
3839
#include <thread>
3940
#include <vector>
4041

@@ -1020,7 +1021,37 @@ struct test_case {
10201021
return t;
10211022
}
10221023

1023-
bool eval(ggml_backend_t backend1, ggml_backend_t backend2, const char * op_name, printer * output_printer) {
1024+
// Checks an op against the test filter, which is a comma separated list of OP names or specific variations
1025+
bool matches_filter(ggml_tensor* op, const char* op_names_filter) {
1026+
if (op_names_filter) {
1027+
const auto op_name = op_desc(op);
1028+
const auto op_full_name = op_name + "(" + vars() + ")";
1029+
std::string_view filter(op_names_filter);
1030+
while (!filter.empty()) {
1031+
auto comma_pos = filter.find_first_of(',');
1032+
const auto lparen_pos = filter.find_first_of('(');
1033+
if (lparen_pos < comma_pos) {
1034+
auto rparen_pos = filter.find_first_of(')');
1035+
comma_pos = filter.find_first_of(',', rparen_pos);
1036+
const auto op_filter = filter.substr(0, comma_pos);
1037+
if (op_filter == op_full_name) {
1038+
return true;
1039+
}
1040+
} else {
1041+
const auto op_filter = filter.substr(0, comma_pos);
1042+
if (op_filter == op_name) {
1043+
return true;
1044+
}
1045+
}
1046+
filter = comma_pos != std::string_view::npos ? filter.substr(comma_pos + 1) : "";
1047+
}
1048+
return false;
1049+
} else {
1050+
return true;
1051+
}
1052+
}
1053+
1054+
bool eval(ggml_backend_t backend1, ggml_backend_t backend2, const char * op_names_filter, printer * output_printer) {
10241055
mode = MODE_TEST;
10251056

10261057
ggml_init_params params = {
@@ -1038,7 +1069,7 @@ struct test_case {
10381069

10391070
ggml_tensor * out = build_graph(ctx);
10401071
std::string current_op_name = op_desc(out);
1041-
if (op_name != nullptr && current_op_name != op_name) {
1072+
if (!matches_filter(out, op_names_filter)) {
10421073
//printf(" %s: skipping\n", op_desc(out).c_str());
10431074
ggml_free(ctx);
10441075
return true;
@@ -1185,7 +1216,7 @@ struct test_case {
11851216
return test_passed;
11861217
}
11871218

1188-
bool eval_perf(ggml_backend_t backend, const char * op_name, printer * output_printer) {
1219+
bool eval_perf(ggml_backend_t backend, const char * op_names_filter, printer * output_printer) {
11891220
mode = MODE_PERF;
11901221

11911222
static const size_t graph_nodes = 8192;
@@ -1200,7 +1231,7 @@ struct test_case {
12001231

12011232
ggml_tensor * out = build_graph(ctx.get());
12021233
std::string current_op_name = op_desc(out);
1203-
if (op_name != nullptr && current_op_name != op_name) {
1234+
if (!matches_filter(out, op_names_filter)) {
12041235
//printf(" %s: skipping\n", op_desc(out).c_str());
12051236
return true;
12061237
}
@@ -1315,7 +1346,7 @@ struct test_case {
13151346
return true;
13161347
}
13171348

1318-
bool eval_support(ggml_backend_t backend, const char * op_name, printer * output_printer) {
1349+
bool eval_support(ggml_backend_t backend, const char * op_names_filter, printer * output_printer) {
13191350
mode = MODE_SUPPORT;
13201351

13211352
static const size_t graph_nodes = 8192;
@@ -1330,7 +1361,7 @@ struct test_case {
13301361

13311362
ggml_tensor * out = build_graph(ctx.get());
13321363
std::string current_op_name = op_desc(out);
1333-
if (op_name != nullptr && current_op_name != op_name) {
1364+
if (!matches_filter(out, op_names_filter)) {
13341365
return true;
13351366
}
13361367

@@ -1347,7 +1378,7 @@ struct test_case {
13471378
return true;
13481379
}
13491380

1350-
bool eval_grad(ggml_backend_t backend, const char * op_name, printer * output_printer) {
1381+
bool eval_grad(ggml_backend_t backend, const char * op_names_filter, printer * output_printer) {
13511382
mode = MODE_GRAD;
13521383
const std::vector<float> expect = grad_expect();
13531384

@@ -1364,7 +1395,7 @@ struct test_case {
13641395

13651396
ggml_tensor * out = build_graph(ctx.get());
13661397

1367-
if ((op_name != nullptr && op_desc(out) != op_name) || out->op == GGML_OP_OPT_STEP_ADAMW) {
1398+
if (!matches_filter(out, op_names_filter) || out->op == GGML_OP_OPT_STEP_ADAMW) {
13681399
return true;
13691400
}
13701401

@@ -5881,7 +5912,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
58815912
return test_cases;
58825913
}
58835914

5884-
static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name, const char * params_filter,
5915+
static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_names_filter, const char * params_filter,
58855916
printer * output_printer) {
58865917
auto filter_test_cases = [](std::vector<std::unique_ptr<test_case>> & test_cases, const char * params_filter) {
58875918
if (params_filter == nullptr) {
@@ -5913,7 +5944,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
59135944

59145945
size_t n_ok = 0;
59155946
for (auto & test : test_cases) {
5916-
if (test->eval(backend, backend_cpu, op_name, output_printer)) {
5947+
if (test->eval(backend, backend_cpu, op_names_filter, output_printer)) {
59175948
n_ok++;
59185949
}
59195950
}
@@ -5929,7 +5960,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
59295960
filter_test_cases(test_cases, params_filter);
59305961
size_t n_ok = 0;
59315962
for (auto & test : test_cases) {
5932-
if (test->eval_grad(backend, op_name, output_printer)) {
5963+
if (test->eval_grad(backend, op_names_filter, output_printer)) {
59335964
n_ok++;
59345965
}
59355966
}
@@ -5942,7 +5973,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
59425973
auto test_cases = make_test_cases_perf();
59435974
filter_test_cases(test_cases, params_filter);
59445975
for (auto & test : test_cases) {
5945-
test->eval_perf(backend, op_name, output_printer);
5976+
test->eval_perf(backend, op_names_filter, output_printer);
59465977
}
59475978
return true;
59485979
}
@@ -5951,7 +5982,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
59515982
auto test_cases = make_test_cases_eval();
59525983
filter_test_cases(test_cases, params_filter);
59535984
for (auto & test : test_cases) {
5954-
test->eval_support(backend, op_name, output_printer);
5985+
test->eval_support(backend, op_names_filter, output_printer);
59555986
}
59565987
return true;
59575988
}
@@ -5973,7 +6004,7 @@ static void usage(char ** argv) {
59736004
int main(int argc, char ** argv) {
59746005
test_mode mode = MODE_TEST;
59756006
output_formats output_format = CONSOLE;
5976-
const char * op_name_filter = nullptr;
6007+
const char * op_names_filter = nullptr;
59776008
const char * backend_filter = nullptr;
59786009
const char * params_filter = nullptr;
59796010

@@ -5988,7 +6019,7 @@ int main(int argc, char ** argv) {
59886019
mode = MODE_SUPPORT;
59896020
} else if (strcmp(argv[i], "-o") == 0) {
59906021
if (i + 1 < argc) {
5991-
op_name_filter = argv[++i];
6022+
op_names_filter = argv[++i];
59926023
} else {
59936024
usage(argv);
59946025
return 1;
@@ -6069,7 +6100,7 @@ int main(int argc, char ** argv) {
60696100
false, "", ggml_backend_dev_description(dev),
60706101
total / 1024 / 1024, free / 1024 / 1024, true));
60716102

6072-
bool ok = test_backend(backend, mode, op_name_filter, params_filter, output_printer.get());
6103+
bool ok = test_backend(backend, mode, op_names_filter, params_filter, output_printer.get());
60736104

60746105
if (ok) {
60756106
n_ok++;

0 commit comments

Comments
 (0)