@@ -438,7 +438,6 @@ static std::pair<int, int> test_forward_backward(
438438        float  weights;
439439        ggml_backend_tensor_get (cd.weights , &weights, 0 , sizeof (float ));
440440        const  bool  subtest_ok = weights == -ndata * .5 ;
441-         TEST_LOG (" %s: ndata=%d weights=%f\n " int ) ndata, (double ) weights);
442441        helper_after_test_forward_backward (optim, __func__, high_level, shuffle, " weights_after_forward_backward" 
443442    }
444443    {
@@ -821,11 +820,7 @@ static std::pair<int, int> test_regression(
821820        ggml_backend_tensor_get (b, &b_fit, 0 , sizeof (float ));
822821        float  tol = adamw ? 1e-2  : 5e-2 ;
823822        const  bool  aok = almost_equal (a_fit, a_true, tol);
824-         if  (!aok)
825-           TEST_LOG (" %s: a_fit=%f a_true=%f\n " double )a_fit, (double )a_true);
826823        const  bool  bok = almost_equal (b_fit, b_true, tol);
827-         if  (!bok)
828-           TEST_LOG (" %s: b_fit=%f b_true=%f\n " double )b_fit, (double )b_true);
829824        const  bool  subtest_ok = aok && bok;
830825        print_ok (__func__, adamw ? subtest_ok : true , npass, ntest, " subtest=weights" 
831826    }
@@ -934,19 +929,49 @@ int main(void) {
934929            printf ("   Device memory: %zu MB (%zu MB free)\n " 1024  / 1024 , free / 1024  / 1024 );
935930            printf (" \n " 
936931
937-             if  (optim == GGML_OPT_OPTIMIZER_TYPE_SGD && !strcmp (devname, " Vulkan0" 
938-               // TODO: even though backend returns false for currently
939-               //  unimplemented sgd op, we still need this
940-               continue ;
941-             if  (!strcmp (devname, " WebGPU" 
942-               //  GGML_OP_SUM implementation missing
943-               continue ;
944-             std::pair<int , int > result = test_backend (backend_sched, backends[i], optim);
932+             bool  skip;
933+             {
934+                 struct  ggml_init_params  params = {
935+                     /* .mem_size   =*/ 6 *ggml_tensor_overhead (),
936+                     /* .mem_buffer =*/ nullptr ,
937+                     /* .no_alloc   =*/ true ,
938+                 };
939+                 ggml_context * ctx = ggml_init (params);
940+                 ggml_tensor * a = ggml_new_tensor_1d (ctx, GGML_TYPE_F32, 1 );
941+                 ggml_set_param (a);
942+                 ggml_tensor * b = ggml_new_tensor_1d (ctx, GGML_TYPE_F32, 1 );
943+                 ggml_tensor * c = ggml_new_tensor_1d (ctx, GGML_TYPE_F32, 1 );
944+                 ggml_tensor * d = ggml_new_tensor_1d (ctx, GGML_TYPE_F32, 1 );
945+ 
946+                 ggml_tensor * t = nullptr ;
947+                 switch  (optim) {
948+                     case  GGML_OPT_OPTIMIZER_TYPE_ADAMW: {
949+                         ggml_tensor * p = ggml_new_tensor_1d (ctx, GGML_TYPE_F32, 7 );
950+                         t = ggml_opt_step_adamw (ctx, a, b, c, d, p);
951+                     } break ;
952+                     case  GGML_OPT_OPTIMIZER_TYPE_SGD: {
953+                         ggml_tensor * p = ggml_new_tensor_1d (ctx, GGML_TYPE_F32, 2 );
954+                         t = ggml_opt_step_sgd (ctx, a, b, p);
955+                     } break ;
956+                     case  GGML_OPT_OPTIMIZER_TYPE_COUNT: {
957+                         GGML_ABORT (" fatal error" 
958+                     }
959+                 }
960+                 skip = !ggml_backend_supports_op (backends[i], t);
961+                 ggml_free (ctx);
962+             }
945963
946-             printf ("   %d/%d tests passed\n " first , result.second );
964+             std::pair<int , int > result;
965+             if  (!skip) {
966+                 result = test_backend (backend_sched, backends[i], optim);
967+                 printf ("   %d/%d tests passed\n " first , result.second );
968+             }
947969
948970            printf ("   Backend %s %s: " ggml_backend_name (backends[i]), ggml_opt_optimizer_name (optim));
949-             if  (result.first  == result.second ) {
971+             if  (skip) {
972+                 printf (" \033 [0;33mSKIPPED\033 [0m\n " 
973+                 n_ok++;
974+             } else  if  (result.first  == result.second ) {
950975                printf (" \033 [1;32mOK\033 [0m\n " 
951976                n_ok++;
952977            } else  {
0 commit comments