@@ -104,16 +104,16 @@ void main() {
104104 tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
105105 tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);
106106
107- coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK , gl_MatrixUseAccumulator> Q;
108- coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK , gl_MatrixUseA> Qf16;
107+ coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, HSK_pad , gl_MatrixUseAccumulator> Q;
108+ coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad , gl_MatrixUseA> Qf16;
109109
110110 uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;
111- coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK ));
111+ coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK_pad ));
112112
113- Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK , gl_MatrixUseA>(Q);
113+ Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSK_pad , gl_MatrixUseA>(Q);
114114 Qf16 *= float16_t(p.scale);
115115
116- coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV , gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV , gl_MatrixUseAccumulator>(0);
116+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad , gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad , gl_MatrixUseAccumulator>(0);
117117
118118 coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L, M;
119119
@@ -140,10 +140,10 @@ void main() {
140140
141141 coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
142142
143- coopmat<float16_t, gl_ScopeWorkgroup, HSK , Bc, gl_MatrixUseB> K_T;
143+ coopmat<float16_t, gl_ScopeWorkgroup, HSK_pad , Bc, gl_MatrixUseB> K_T;
144144
145145 uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
146- coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK ), tensorViewTranspose DECODEFUNC);
146+ coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad ), tensorViewTranspose DECODEFUNC);
147147 S = coopMatMulAdd(Qf16, K_T, S);
148148
149149 if (p.logit_softcap != 0.0f) {
@@ -208,31 +208,31 @@ void main() {
208208 rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0);
209209 rowsum = coopMatMulAdd(P_A, One, rowsum);
210210
211- coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV , gl_MatrixUseB> V;
211+ coopmat<float16_t, gl_ScopeWorkgroup, Bc, HSV_pad , gl_MatrixUseB> V;
212212 uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;
213- coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV ) DECODEFUNC);
213+ coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad ) DECODEFUNC);
214214
215215 L = eM*L + rowsum;
216216
217217 // This is the "diagonal" matrix in the paper, but since we do componentwise
218218 // multiply rather than matrix multiply it has the diagonal element smeared
219219 // across the row
220- coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV , gl_MatrixUseAccumulator> eMdiag;
220+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad , gl_MatrixUseAccumulator> eMdiag;
221221
222222 // resize eM by using smear/reduce
223223 coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);
224224
225225 // multiply with fp16 accumulation, then add to O.
226- coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV , gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV , gl_MatrixUseAccumulator>(0);
226+ coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad , gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, HSV_pad , gl_MatrixUseAccumulator>(0);
227227 PV = coopMatMulAdd(P_A, V, PV);
228228
229- O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV , gl_MatrixUseAccumulator>(PV);
229+ O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad , gl_MatrixUseAccumulator>(PV);
230230 }
231231
232232 // If there is split_k, then the split_k resolve shader does the final
233233 // division by L. Store the intermediate O value and per-row m and L values.
234234 if (p.k_num > 1) {
235- coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV , gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV , gl_MatrixUseAccumulator>(O);
235+ coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad , gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad , gl_MatrixUseAccumulator>(O);
236236
237237 uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num);
238238 coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
@@ -243,16 +243,16 @@ void main() {
243243 return;
244244 }
245245
246- coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV , gl_MatrixUseAccumulator> Ldiag;
246+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad , gl_MatrixUseAccumulator> Ldiag;
247247
248248 // resize L by using smear/reduce
249249 coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce);
250250
251251 if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) {
252- coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV , gl_MatrixUseAccumulator> S;
252+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad , gl_MatrixUseAccumulator> S;
253253 coopMatPerElementNV(S, S, perElemOpGetSink, iq2);
254254
255- coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV , gl_MatrixUseAccumulator> Mr;
255+ coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, HSV_pad , gl_MatrixUseAccumulator> Mr;
256256
257257 // resize M by using smear/reduce
258258 coopMatReduceNV(Mr, M, gl_CooperativeMatrixReduceRowNV, smearReduce);
@@ -285,7 +285,7 @@ void main() {
285285
286286 uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
287287
288- coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV , gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV , gl_MatrixUseAccumulator>(O);
288+ coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad , gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad , gl_MatrixUseAccumulator>(O);
289289 if (p.gqa_ratio > 1) {
290290 coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
291291 } else {
@@ -295,6 +295,6 @@ void main() {
295295 // permute dimensions
296296 tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);
297297
298- coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, HSV ), tensorViewPermute);
298+ coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, HSV_pad ), tensorViewPermute);
299299 }
300300}
0 commit comments