@@ -40,6 +40,9 @@ def is_blackwell():
4040 return is_cuda () and torch .cuda .get_device_capability ()[0 ] == 10
4141
4242
43+ # FIXME: Revert temporary source code modification done in last commit of PR #4399.
44+
45+
4346@triton .jit
4447def _attn_fwd_inner (acc , l_i , m_i , q , #
4548 desc_k , desc_v , #
@@ -65,7 +68,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, #
6568 for start_n in tl .range (lo , hi , BLOCK_N , warp_specialize = warp_specialize ):
6669 start_n = tl .multiple_of (start_n , BLOCK_N )
6770 # -- compute qk ----
68- k = desc_k .load ([offsetk_y , 0 ]). T
71+ k = desc_k .load ([0 , offsetk_y ])
6972 qk = tl .dot (q , k )
7073 if STAGE == 2 :
7174 mask = offs_m [:, None ] >= (start_n + offs_n [None , :])
@@ -83,7 +86,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, #
8386 acc = acc * alpha [:, None ]
8487 # prepare p and v for the dot
8588 if dtype == tl .float8e5 :
86- v = desc_v .load ([0 , offsetv_y ]). T
89+ v = desc_v .load ([offsetv_y , 0 ])
8790 else :
8891 v = desc_v .load ([offsetv_y , 0 ])
8992 p = p .to (dtype )
@@ -176,13 +179,13 @@ def _attn_fwd(sm_scale, M, #
176179 desc_q = _maybe_make_tensor_desc (desc_q , shape = [y_dim , HEAD_DIM ], strides = [HEAD_DIM , 1 ],
177180 block_shape = [BLOCK_M , HEAD_DIM ])
178181 if FP8_OUTPUT :
179- desc_v = _maybe_make_tensor_desc (desc_v , shape = [HEAD_DIM , y_dim ], strides = [N_CTX , 1 ],
180- block_shape = [HEAD_DIM , BLOCK_N ])
182+ desc_v = _maybe_make_tensor_desc (desc_v , shape = [y_dim , HEAD_DIM ], strides = [1 , N_CTX ],
183+ block_shape = [BLOCK_N , HEAD_DIM ])
181184 else :
182185 desc_v = _maybe_make_tensor_desc (desc_v , shape = [y_dim , HEAD_DIM ], strides = [HEAD_DIM , 1 ],
183186 block_shape = [BLOCK_N , HEAD_DIM ])
184- desc_k = _maybe_make_tensor_desc (desc_k , shape = [y_dim , HEAD_DIM ], strides = [HEAD_DIM , 1 ],
185- block_shape = [BLOCK_N , HEAD_DIM ])
187+ desc_k = _maybe_make_tensor_desc (desc_k , shape = [HEAD_DIM , y_dim ], strides = [1 , HEAD_DIM ],
188+ block_shape = [HEAD_DIM , BLOCK_N ])
186189 desc_o = _maybe_make_tensor_desc (desc_o , shape = [y_dim , HEAD_DIM ], strides = [HEAD_DIM , 1 ],
187190 block_shape = [BLOCK_M , HEAD_DIM ])
188191
0 commit comments