Skip to content

Commit 93c3da7

Browse files
author
bssrdf
committed
restore test-conv2d.cpp test
1 parent 0491858 commit 93c3da7

File tree

1 file changed

+20
-22
lines changed

1 file changed

+20
-22
lines changed

tests/test-conv2d.cpp

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ struct test_model {
3636

3737
void load_model(test_model & model, bool use_gpu = false) {
3838
// create data
39-
int KW = 3, KH = 3, IC = 32, OC = 32;
40-
int IW = 28, IH = 40, N = 1;
39+
int KW = 3, KH = 3, IC = 10, OC = 10;
40+
int IW = 8, IH = 6, N = 1;
4141

4242
// Initialize adata
4343
std::vector<float> adata(KW * KH * IC * OC);
@@ -365,28 +365,26 @@ int main(void)
365365

366366
printf("\nPerforming test:\n");
367367

368-
// bool passed = true;
369-
// for(int i = 0; i < n_conv2d_test; i++) {
370-
// if(
371-
// im2col_data[i] != expected_im2col[i]) {
372-
// passed = false;
373-
// break;
374-
// }
375-
// }
376-
377-
// printf("ggml_im2col (%d): %s\n", (int) ggml_nelements(im2col_res), passed && (ggml_nelements(im2col_res) == n_im2col_test) ? "\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m");
378-
379-
// passed = true;
380-
// printf("[");
381-
for(int j = 0; j < 4; j++) {
382-
printf("[");
383-
for(int i = 0; i < 28; i++) {
384-
printf("%.1f, ", conv2d_data[i]);
368+
bool passed = true;
369+
for(int i = 0; i < n_conv2d_test; i++) {
370+
if(
371+
im2col_data[i] != expected_im2col[i]) {
372+
passed = false;
373+
break;
385374
}
386-
printf("]\n");
387375
}
388-
389-
// printf("ggml_conv2d (%d): %s\n", (int) ggml_nelements(conv2d_res), passed && (ggml_nelements(conv2d_res) == n_conv2d_test) ? "\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m");
376+
377+
printf("ggml_im2col (%d): %s\n", (int) ggml_nelements(im2col_res), passed && (ggml_nelements(im2col_res) == n_im2col_test) ? "\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m");
378+
379+
passed = true;
380+
for(int i = 0; i < n_conv2d_test; i++) {
381+
if(conv2d_data[i] != expected_conv2d[i]) {
382+
passed = false;
383+
break;
384+
}
385+
}
386+
387+
printf("ggml_conv2d (%d): %s\n", (int) ggml_nelements(conv2d_res), passed && (ggml_nelements(conv2d_res) == n_conv2d_test) ? "\033[32mPASSED\033[0m" : "\033[31mFAILED\033[0m");
390388

391389
ggml_free(model.ctx);
392390

0 commit comments

Comments
 (0)