@@ -29,6 +29,10 @@ ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
     return max(x, y);
 }
 
+float16_t maxReduceFp16(const in float16_t x, const in float16_t y) {
+    return max(x, y);
+}
+
 ACC_TYPE smearReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
     return x;
 }
@@ -142,49 +146,63 @@ void main() {
     [[dont_unroll]]
     for (uint32_t j = start_j; j < end_j; ++j) {
 
-        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
-
-        coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T;
-
-        uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
-        coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC);
-        S = coopMatMulAdd(Qf16, K_T, S);
-
-        if (p.logit_softcap != 0.0f) {
-            [[unroll]]
-            for (int k = 0; k < S.length(); ++k) {
-                S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]);
-            }
-        }
-
+        coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
         if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
             bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
 
             if (nem1_bounds_check) {
                 tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
                 tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
                 tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
+                tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t
 
-                coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
+                coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;
 
                 coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
 
-                S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
+                // skip the block if the mask is entirely -inf
+                coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
+                if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
+                    continue;
+                }
             } else {
                 tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
                 // Don't clamp against nem1 when GQA is enabled
                 uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1;
                 tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
                 tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
 
-                coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
+                coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mvmax;
 
                 coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
 
-                S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
+                // skip the block if the mask is entirely -inf
+                coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
+                if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
+                    continue;
+                }
             }
         }
 
+        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
+
+        coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad, Bc, gl_MatrixUseB> K_T;
+
+        uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
+        coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC);
+        S = coopMatMulAdd(Qf16, K_T, S);
+
+        if (p.logit_softcap != 0.0f) {
+            [[unroll]]
+            for (int k = 0; k < S.length(); ++k) {
+                S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]);
+            }
+        }
+
+        if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
+            S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
+        }
+
         // Clear padding elements to -inf, so they don't contribute to rowmax
         if (Clamp != 0 &&
             ((j + 1) * Bc > KV ||
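
Note on the early-out above (not part of the patch): coopMatReduceNV with gl_CooperativeMatrixReduceRowAndColumnNV folds the whole tile into a single value replicated to every element of the result, so inspecting mvmax[0] alone is sufficient. A minimal GLSL sketch of the pattern inside the KV-block loop, assuming GL_NV_cooperative_matrix2 and this shader's existing Br, Bc, NEG_FLT_MAX_OVER_2, and maxReduceFp16 definitions:

    // Sketch only. 0xfc00 is the IEEE fp16 bit pattern for -infinity
    // (sign=1, exponent=0x1f, mantissa=0), so out-of-bounds mask loads
    // clamped to that value also register as masked.
    coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv, mvmax;
    coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
    // Max-reduce across rows and columns: every element of mvmax now
    // holds the tile-wide maximum of the mask values.
    coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16);
    if (mvmax[0] <= NEG_FLT_MAX_OVER_2) {
        continue; // tile fully masked: skip the K^T load, matmul, and softmax work
    }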