@@ -2282,11 +2282,12 @@ struct test_soft_max : public test_case {
22822282 const ggml_type type;
22832283 const std::array<int64_t , 4 > ne;
22842284 const bool mask;
2285+ const ggml_type m_prec;
22852286 const float scale;
22862287 const float max_bias;
22872288
22882289 std::string vars () override {
2289- return VARS_TO_STR5 (type, ne, mask, scale, max_bias);
2290+ return VARS_TO_STR6 (type, ne, mask, m_prec , scale, max_bias);
22902291 }
22912292
22922293 // the 1024 test with bias occasionally fails:
@@ -2298,9 +2299,10 @@ struct test_soft_max : public test_case {
22982299 test_soft_max (ggml_type type = GGML_TYPE_F32,
22992300 std::array<int64_t , 4 > ne = {10 , 5 , 4 , 3 },
23002301 bool mask = false ,
2302+ ggml_type m_prec = GGML_TYPE_F32,
23012303 float scale = 1 .0f ,
23022304 float max_bias = 0 .0f )
2303- : type(type), ne(ne), mask(mask), scale(scale), max_bias(max_bias) {}
2305+ : type(type), ne(ne), mask(mask), m_prec(m_prec), scale(scale), max_bias(max_bias) {}
23042306
23052307 ggml_tensor * build_graph (ggml_context * ctx) override {
23062308 ggml_tensor * a = ggml_new_tensor (ctx, type, 4 , ne.data ());
@@ -2309,7 +2311,7 @@ struct test_soft_max : public test_case {
23092311
23102312 ggml_tensor * mask = nullptr ;
23112313 if (this ->mask ) {
2312- mask = ggml_new_tensor_2d (ctx, GGML_TYPE_F32 , ne[0 ], ne[1 ]);
2314+ mask = ggml_new_tensor_2d (ctx, m_prec , ne[0 ], ne[1 ]);
23132315 ggml_set_name (mask, " mask" );
23142316 }
23152317
@@ -4078,17 +4080,28 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
40784080 for (float scale : {1 .0f , 0 .1f }) {
40794081 for (int64_t ne0 : {16 , 1024 }) {
40804082 for (int64_t ne1 : {16 , 1024 }) {
4081- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0, ne1, 1 , 1 }, mask, scale, max_bias));
4082- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0-1 , ne1-1 , 1 , 1 }, mask, scale, max_bias));
4083+ if (mask) {
4084+ for (ggml_type m_prec : {GGML_TYPE_F32, GGML_TYPE_F16}) {
4085+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0, ne1, 1 , 1 }, mask, m_prec, scale, max_bias));
4086+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0-1 , ne1-1 , 1 , 1 }, mask, m_prec, scale, max_bias));
4087+ }
4088+ } else {
4089+ /* The precision of mask here doesn't matter as boolean mask is false */
4090+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0, ne1, 1 , 1 }, mask, GGML_TYPE_F32, scale, max_bias));
4091+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {ne0-1 , ne1-1 , 1 , 1 }, mask, GGML_TYPE_F32, scale, max_bias));
4092+ }
40834093 }
40844094 }
40854095 }
40864096 }
40874097 }
4088- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {16 , 2 , 32 , 1 }, true , 0 .1f , 0 .0f ));
4089- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {16 , 2 , 32 , 1 }, false , 0 .1f , 0 .0f ));
4090- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {32 , 2 , 32 , 1 }, true , 0 .1f , 0 .0f ));
4091- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {32 , 2 , 32 , 1 }, true , 0 .1f , 8 .0f ));
4098+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {16 , 2 , 32 , 1 }, true , GGML_TYPE_F32, 0 .1f , 0 .0f ));
4099+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {16 , 2 , 32 , 1 }, true , GGML_TYPE_F16, 0 .1f , 0 .0f ));
4100+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {16 , 2 , 32 , 1 }, false , GGML_TYPE_F32, 0 .1f , 0 .0f ));
4101+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {32 , 2 , 32 , 1 }, true , GGML_TYPE_F32, 0 .1f , 0 .0f ));
4102+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {32 , 2 , 32 , 1 }, true , GGML_TYPE_F16, 0 .1f , 0 .0f ));
4103+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {32 , 2 , 32 , 1 }, true , GGML_TYPE_F32, 0 .1f , 8 .0f ));
4104+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {32 , 2 , 32 , 1 }, true , GGML_TYPE_F16, 0 .1f , 8 .0f ));
40924105
40934106 for (float max_bias : {0 .0f , 8 .0f }) {
40944107 for (float scale : {1 .0f , 0 .1f }) {
@@ -4224,13 +4237,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
42244237 test_cases.emplace_back (new test_cpy (GGML_TYPE_F32, GGML_TYPE_F32, {8192 , 512 , 2 , 1 }, {0 , 2 , 1 , 3 }));
42254238 test_cases.emplace_back (new test_cpy (GGML_TYPE_F32, GGML_TYPE_F32, {3072 , 512 , 2 , 1 }, {0 , 2 , 1 , 3 }));
42264239
4227- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {4096 , 4096 , 5 , 1 }, false , 1 .0f , 0 .0f ));
4228- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {77 , 4096 , 5 , 1 }, false , 1 .0f , 0 .0f ));
4229- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {1024 , 1024 , 10 , 1 }, false , 1 .0f , 0 .0f ));
4230- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {77 , 1024 , 10 , 1 }, false , 1 .0f , 0 .0f ));
4231- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {256 , 256 , 20 , 1 }, false , 1 .0f , 0 .0f ));
4232- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {64 , 64 , 20 , 1 }, false , 1 .0f , 0 .0f ));
4233- test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {77 , 64 , 20 , 1 }, false , 1 .0f , 0 .0f ));
4240+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {4096 , 4096 , 5 , 1 }, false , GGML_TYPE_F32, 1 .0f , 0 .0f ));
4241+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {77 , 4096 , 5 , 1 }, false , GGML_TYPE_F32, 1 .0f , 0 .0f ));
4242+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {1024 , 1024 , 10 , 1 }, false , GGML_TYPE_F32, 1 .0f , 0 .0f ));
4243+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {77 , 1024 , 10 , 1 }, false , GGML_TYPE_F32, 1 .0f , 0 .0f ));
4244+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {256 , 256 , 20 , 1 }, false , GGML_TYPE_F32, 1 .0f , 0 .0f ));
4245+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {64 , 64 , 20 , 1 }, false , GGML_TYPE_F32, 1 .0f , 0 .0f ));
4246+ test_cases.emplace_back (new test_soft_max (GGML_TYPE_F32, {77 , 64 , 20 , 1 }, false , GGML_TYPE_F32, 1 .0f , 0 .0f ));
42344247
42354248 test_cases.emplace_back (new test_argmax (GGML_TYPE_F32, {32 , 10 , 1 , 1 }));
42364249 test_cases.emplace_back (new test_argmax (GGML_TYPE_F32, {1024 , 10 , 1 , 1 }));
0 commit comments