@@ -14,8 +14,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
1414 FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(data_a[idx]);
1515 buf_a[buf_idx ] = aa.xy;
1616 buf_a[buf_idx + 1] = aa.zw;
17- #else // LOAD_VEC_A == 2
18- const uint idx = pos_a * 2 + col * p.stride_a + row * 2;
17+ #else // LOAD_VEC_BATCH_A == 2
18+ const uint idx = pos_a + col * p.stride_a + row * 2;
1919 const uint buf_idx = col * SHMEM_STRIDE + row;
2020 if (idx_m < p.M && block + row * 2 + 1 < end_k) {
2121 buf_a[buf_idx] = FLOAT_TYPE_VEC2(data_a[idx],
@@ -33,8 +33,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
3333 FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_a[idx]));
3434 buf_a[buf_idx ] = aa.xy;
3535 buf_a[buf_idx + 1] = aa.zw;
36- #else // LOAD_VEC_A == 2
37- const uint idx = pos_a * 2 + col * p.stride_a + row * 2;
36+ #else // LOAD_VEC_BATCH_A == 2
37+ const uint idx = pos_a + col * p.stride_a + row * 2;
3838 const uint buf_idx = col * SHMEM_STRIDE + row;
3939 if (idx_m < p.M && block + row * 2 + 1 < end_k) {
4040 buf_a[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_a[idx]),
@@ -500,8 +500,8 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin
500500#endif
501501 buf_b[buf_idx + 0] = bb.xy;
502502 buf_b[buf_idx + 1] = bb.zw;
503- #else // LOAD_VEC_B == 2
504- const uint idx = pos_b * 2 + col * p.stride_b + row * 2;
503+ #else // LOAD_VEC_BATCH_B == 2
504+ const uint idx = pos_b + col * p.stride_b + row * 2;
505505 const uint buf_idx = col * SHMEM_STRIDE + row;
506506 if (idx_n < p.N && block + row * 2 + 1 < end_k) {
507507 buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]),
@@ -536,17 +536,17 @@ void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uin
536536#endif
537537 buf_b[buf_idx + 0] = bb.xy;
538538 buf_b[buf_idx + 1] = bb.zw;
539- #else // LOAD_VEC_B == 2
539+ #else // LOAD_VEC_BATCH_B == 2
540540 const uint row_i = ic * BN + col;
541541 const uint buf_idx = col * SHMEM_STRIDE + row;
542542 if (row_i < _ne1 && block + row * 2 + 1 < end_k) {
543543 const u16vec2 row_idx = row_ids[col];
544- const uint idx = pos_b * 2 + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
544+ const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
545545 buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]),
546546 TO_FLOAT_TYPE(data_b[idx + 1]));
547547 } else if (row_i < _ne1 && block + row * 2 < end_k) {
548548 const u16vec2 row_idx = row_ids[col];
549- const uint idx = pos_b * 2 + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
549+ const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2;
550550 buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), 0.0f);
551551 } else {
552552 buf_b[buf_idx] = FLOAT_TYPE_VEC2(0.0f);
0 commit comments