@@ -40,6 +40,9 @@ def is_blackwell():
40
40
return is_cuda () and torch .cuda .get_device_capability ()[0 ] == 10
41
41
42
42
43
+ # FIXME: Revert temporary source code modification done in last commit of PR #4399.
44
+
45
+
43
46
@triton .jit
44
47
def _attn_fwd_inner (acc , l_i , m_i , q , #
45
48
desc_k , desc_v , #
@@ -65,7 +68,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, #
65
68
for start_n in tl .range (lo , hi , BLOCK_N , warp_specialize = warp_specialize ):
66
69
start_n = tl .multiple_of (start_n , BLOCK_N )
67
70
# -- compute qk ----
68
- k = desc_k .load ([offsetk_y , 0 ]). T
71
+ k = desc_k .load ([0 , offsetk_y ])
69
72
qk = tl .dot (q , k )
70
73
if STAGE == 2 :
71
74
mask = offs_m [:, None ] >= (start_n + offs_n [None , :])
@@ -83,7 +86,7 @@ def _attn_fwd_inner(acc, l_i, m_i, q, #
83
86
acc = acc * alpha [:, None ]
84
87
# prepare p and v for the dot
85
88
if dtype == tl .float8e5 :
86
- v = desc_v .load ([0 , offsetv_y ]). T
89
+ v = desc_v .load ([offsetv_y , 0 ])
87
90
else :
88
91
v = desc_v .load ([offsetv_y , 0 ])
89
92
p = p .to (dtype )
@@ -176,13 +179,13 @@ def _attn_fwd(sm_scale, M, #
176
179
desc_q = _maybe_make_tensor_desc (desc_q , shape = [y_dim , HEAD_DIM ], strides = [HEAD_DIM , 1 ],
177
180
block_shape = [BLOCK_M , HEAD_DIM ])
178
181
if FP8_OUTPUT :
179
- desc_v = _maybe_make_tensor_desc (desc_v , shape = [HEAD_DIM , y_dim ], strides = [N_CTX , 1 ],
180
- block_shape = [HEAD_DIM , BLOCK_N ])
182
+ desc_v = _maybe_make_tensor_desc (desc_v , shape = [y_dim , HEAD_DIM ], strides = [1 , N_CTX ],
183
+ block_shape = [BLOCK_N , HEAD_DIM ])
181
184
else :
182
185
desc_v = _maybe_make_tensor_desc (desc_v , shape = [y_dim , HEAD_DIM ], strides = [HEAD_DIM , 1 ],
183
186
block_shape = [BLOCK_N , HEAD_DIM ])
184
- desc_k = _maybe_make_tensor_desc (desc_k , shape = [y_dim , HEAD_DIM ], strides = [HEAD_DIM , 1 ],
185
- block_shape = [BLOCK_N , HEAD_DIM ])
187
+ desc_k = _maybe_make_tensor_desc (desc_k , shape = [HEAD_DIM , y_dim ], strides = [1 , HEAD_DIM ],
188
+ block_shape = [HEAD_DIM , BLOCK_N ])
186
189
desc_o = _maybe_make_tensor_desc (desc_o , shape = [y_dim , HEAD_DIM ], strides = [HEAD_DIM , 1 ],
187
190
block_shape = [BLOCK_M , HEAD_DIM ])
188
191
0 commit comments