@@ -229,16 +229,14 @@ void main() {
229229            uint32_t B_ly    = r_offset + Ar;
230230            uint32_t B_lx    = Ac;
231231            uint32_t K_idx   = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/
232-             float    val;
233-             if (K_idx >= K || CRS_idx_a >= CRS) {
234-                 val = 0.0;
235-             } else {
236232#ifdef TRANSPOSE
237-                  uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + K_idx * p.nb02 + Cin_idx_a * p.nb03, K * CRS - 1);
233+             uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + K_idx * p.nb02 + Cin_idx_a * p.nb03, K * CRS - 1);
238234#else
239-                  uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + Cin_idx_a * p.nb02 + K_idx * p.nb03, K * CRS - 1);
235+             uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + Cin_idx_a * p.nb02 + K_idx * p.nb03, K * CRS - 1);
240236#endif
241-                 val = knl_data[knl_idx];
237+             float    val     = knl_data[knl_idx];
238+             if (K_idx >= K || CRS_idx_a >= CRS) {
239+                 val = 0.0;
242240            }
243241            Ash[B_ly * Ash_stride + B_lx] = SHMEM_TYPE(val);
244242        }
@@ -286,18 +284,16 @@ void main() {
286284            uint32_t H_idx = OH_idx * p.s1 + KH_idx_b * p.d1 - p.p1;
287285            uint32_t W_idx = OW_idx * p.s0 + KW_idx_b * p.d0 - p.p0;
288286#endif
289-             float val;
287+             uint32_t src_idx =
288+                 min(max(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), p.Cin * p.N * p.W * p.H - 1);
289+             float val = src_data[src_idx];
290290            if (CRS_idx_b >= CRS || NPQ_idx >= NPQ
291291                || int32_t(H_idx) < 0 || H_idx >= p.H || int32_t(W_idx) < 0 || W_idx >= p.W
292292#ifdef TRANSPOSE
293293                || (H_idx_x_s1 - H_idx * p.s1 != 0) || (W_idx_x_s0 - W_idx * p.s0 != 0)
294294#endif
295295                ) {
296296                val = 0.0;
297-             } else {
298-                 uint32_t src_idx =
299-                     min(max(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), p.Cin * p.N * p.W * p.H - 1);
300-                 val = src_data[src_idx];
301297            }
302298            Bsh[B_ly * Bsh_stride + B_lx] = SHMEM_TYPE(val);
303299        }
0 commit comments