@@ -154,15 +154,31 @@ void main() {
154154 }
155155
156156 if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
157- tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
158- tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
159- tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
157+ bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
160158
161- coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
159+ if (nem1_bounds_check) {
160+ tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
161+ tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
162+ tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
162163
163- coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br , Br, j * Bc, Bc)) ;
164+ coopmat<float16_t, gl_ScopeWorkgroup , Br, Bc, gl_MatrixUseAccumulator> mv ;
164165
165- S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
166+ coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
167+
168+ S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
169+ } else {
170+ tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
171+ // Don't clamp against nem1 when GQA is enabled
172+ uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1;
173+ tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV);
174+ tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);
175+
176+ coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
177+
178+ coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
179+
180+ S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
181+ }
166182 }
167183
168184 // Clear padding elements to -inf, so they don't contribute to rowmax
0 commit comments