@@ -2347,11 +2347,12 @@ struct test_soft_max : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
     const bool mask;
+    const ggml_type m_prec;
     const float scale;
     const float max_bias;
 
     std::string vars() override {
-        return VARS_TO_STR5(type, ne, mask, scale, max_bias);
+        return VARS_TO_STR6(type, ne, mask, m_prec, scale, max_bias);
     }
 
     // the 1024 test with bias occasionally fails:
@@ -2363,9 +2364,10 @@ struct test_soft_max : public test_case {
     test_soft_max(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {10, 5, 4, 3},
             bool mask = false,
+            ggml_type m_prec = GGML_TYPE_F32,
             float scale = 1.0f,
             float max_bias = 0.0f)
-        : type(type), ne(ne), mask(mask), scale(scale), max_bias(max_bias) {}
+        : type(type), ne(ne), mask(mask), m_prec(m_prec), scale(scale), max_bias(max_bias) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
@@ -2374,7 +2376,7 @@ struct test_soft_max : public test_case {
 
         ggml_tensor * mask = nullptr;
         if (this->mask) {
-            mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne[0], ne[1]);
+            mask = ggml_new_tensor_2d(ctx, m_prec, ne[0], ne[1]);
            ggml_set_name(mask, "mask");
        }
 
@@ -4150,17 +4152,28 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
            for (float scale : {1.0f, 0.1f}) {
                for (int64_t ne0 : {16, 1024}) {
                    for (int64_t ne1 : {16, 1024}) {
-                        test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0,   ne1,   1, 1}, mask, scale, max_bias));
-                        test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, scale, max_bias));
+                        if (mask) {
+                            for (ggml_type m_prec : {GGML_TYPE_F32, GGML_TYPE_F16}) {
+                                test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0,   ne1,   1, 1}, mask, m_prec, scale, max_bias));
+                                test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, m_prec, scale, max_bias));
+                            }
+                        } else {
+                            /* The precision of mask here doesn't matter as boolean mask is false */
+                            test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0,   ne1,   1, 1}, mask, GGML_TYPE_F32, scale, max_bias));
+                            test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, GGML_TYPE_F32, scale, max_bias));
+                        }
                    }
                }
            }
        }
    }
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true,  0.1f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  0.1f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  0.1f, 8.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true,  GGML_TYPE_F32, 0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, true,  GGML_TYPE_F16, 0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, GGML_TYPE_F32, 0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  GGML_TYPE_F32, 0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  GGML_TYPE_F16, 0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  GGML_TYPE_F32, 0.1f, 8.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  GGML_TYPE_F16, 0.1f, 8.0f));
 
    for (float max_bias : {0.0f, 8.0f}) {
        for (float scale : {1.0f, 0.1f}) {
@@ -4296,13 +4309,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {8192, 512, 2, 1}, {0, 2, 1, 3}));
    test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, GGML_TYPE_F32, {3072, 512, 2, 1}, {0, 2, 1, 3}));
 
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {256, 256, 20, 1}, false, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, 1.0f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {4096, 4096, 5, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 4096, 5, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {1024, 1024, 10, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 1024, 10, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {256, 256, 20, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {64, 64, 20, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {77, 64, 20, 1}, false, GGML_TYPE_F32, 1.0f, 0.0f));
 
    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 10, 1, 1}));
    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
0 commit comments