@@ -131,6 +131,50 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m
131131 }
132132}
133133
134+ // generate an F16 mask where certain blocks are randomly masked with -INF value
135+ static void init_tensor_kq_mask (ggml_tensor * tensor, float min = -1 .0f , float max = 1 .0f ) {
136+ GGML_ASSERT (tensor->type == GGML_TYPE_F16);
137+
138+ GGML_TENSOR_LOCALS ( int32_t , ne, tensor, ne);
139+
140+ std::vector<float > data_f32 (ne0*ne1*ne2*ne3);
141+ std::vector<ggml_fp16_t > data_f16 (ne0*ne1*ne2*ne3);
142+
143+ std::random_device rd;
144+ std::mt19937 gen (rd ());
145+ std::uniform_real_distribution<float > dis (min, max);
146+
147+ for (size_t i = 0 ; i < data_f32.size (); i++) {
148+ data_f32[i] = dis (gen);
149+ }
150+
151+ // block size
152+ const int blck0 = 128 ;
153+ const int blck1 = 16 ;
154+
155+ // number of INF blocks
156+ const int n_inf_blocks = 0.1 *(ne0*ne1*ne2*ne3)/(blck0*blck1);
157+
158+ for (int b = 0 ; b < n_inf_blocks; b++) {
159+ const int p3 = (rd () % ne3);
160+ const int p2 = (rd () % ne2);
161+ const int p1 = (rd () % ne1);
162+ const int p0 = (rd () % ne0);
163+
164+ for (int i1 = 0 ; i1 < blck1 && p1 + i1 < ne1; i1++) {
165+ const int idx = p3*ne2*ne1*ne0 + p2*ne1*ne0 + (p1 + i1)*ne0 + p0;
166+
167+ for (int i0 = 0 ; i0 < blck0 && p0 + i0 < ne0; i0++) {
168+ data_f32[idx + i0] = -INFINITY;
169+ }
170+ }
171+ }
172+
173+ ggml_fp32_to_fp16_row (data_f32.data (), data_f16.data (), ne0*ne1*ne2*ne3);
174+
175+ ggml_backend_tensor_set (tensor, data_f16.data (), 0 , data_f16.size ()*sizeof (ggml_fp16_t ));
176+ }
177+
134178static std::vector<float > tensor_to_float (const ggml_tensor * t) {
135179 std::vector<float > tv;
136180 tv.reserve (ggml_nelements (t));
@@ -5104,6 +5148,8 @@ struct test_flash_attn_ext : public test_case {
51045148 if (strcmp (t->name , " s" ) == 0 ) {
51055149 // make the sink values more noticeable in order to trigger a test failure when the implementation is wrong
51065150 init_tensor_uniform (t, -10 .0f , 10 .0f );
5151+ } else if (strcmp (t->name , " m" ) == 0 ) {
5152+ init_tensor_kq_mask (t);
51075153 } else {
51085154 init_tensor_uniform (t);
51095155 }
0 commit comments