Commit c7fea1b

Fix bwd Mode in flash-attention.py (#772)
1 parent 9a32ed0 commit c7fea1b

File tree

1 file changed (+11, -2 lines)


python/perf-kernels/flash-attention.py

Lines changed: 11 additions & 2 deletions
@@ -1168,7 +1168,8 @@ def forward(ctx, q, k, v, o, metadata: MetaData):
         return o, encoded_softmax, attn_fwd.best_config

     @staticmethod
-    def backward(ctx, do, _):
+    def backward(ctx, *gradients):
+        do = gradients[0]
         if torch.version.hip is not None:
             BLOCK = 64
         else:
@@ -1936,6 +1937,14 @@ def run_benchmark(custom, args):
     else:
         x_vals_list = nonvarlen_benchmark_configs()

+    if mode == 'bwd':
+        # Only those with N_CTX_Q == N_CTX_K work
+        new_x = []
+        for v in x_vals_list:
+            if v[-1] == v[-2]:
+                new_x.append(v)
+        x_vals_list = new_x
+
     if args.model:
         x_vals_list = model_benchmark_configs(args)
         x_names = ['model', 'BATCH', 'HQ', 'HK', 'N_CTX_Q', 'N_CTX_K', 'D_HEAD']
@@ -2013,7 +2022,7 @@ def bench_flash_attention(BATCH, HQ, HK, N_CTX_Q, N_CTX_K, D_HEAD, dtype, causal
     input_metadata.set_persistent(args.persistent)
     fn = lambda: attention(q, k, v, o, input_metadata)
     if mode == 'bwd':
-        o, _ = fn()
+        o, _, _ = fn()
         do = torch.randn_like(o)
         fn = lambda: o.backward(do, retain_graph=True)
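For context, a minimal, self-contained sketch of the torch.autograd.Function contract this fix leans on: backward is called with one incoming gradient per value returned by forward (None is passed for non-tensor outputs). Since forward in flash-attention.py returns three values (o, encoded_softmax, attn_fwd.best_config), backward(ctx, *gradients) with do = gradients[0] accepts all three, and callers now unpack three results (o, _, _ = fn()). The _ToyAttention class below, its naive softmax-attention math, and the placeholder third output are illustrative assumptions, not code from this repository.

# Minimal sketch, plain PyTorch only (no Triton kernels): forward returns three
# outputs, so autograd hands backward three incoming gradients; only the first
# (the gradient of the attention output) is used.
import torch


class _ToyAttention(torch.autograd.Function):  # hypothetical example class

    @staticmethod
    def forward(ctx, q, k, v):
        ctx.save_for_backward(q, k, v)
        o = torch.softmax(q @ k.transpose(-1, -2), dim=-1) @ v
        # Extra auxiliary outputs, analogous to
        # (o, encoded_softmax, attn_fwd.best_config) in flash-attention.py.
        return o, None, "best_config_placeholder"

    @staticmethod
    def backward(ctx, *gradients):
        # One incoming gradient per forward output; only d(loss)/d(o) matters here.
        do = gradients[0]
        q, k, v = ctx.saved_tensors
        p = torch.softmax(q @ k.transpose(-1, -2), dim=-1)
        dv = p.transpose(-1, -2) @ do
        dp = do @ v.transpose(-1, -2)
        ds = p * (dp - (dp * p).sum(dim=-1, keepdim=True))  # softmax backward
        dq = ds @ k
        dk = ds.transpose(-1, -2) @ q
        return dq, dk, dv


q, k, v = (torch.randn(2, 4, 8, requires_grad=True) for _ in range(3))
o, _, _ = _ToyAttention.apply(q, k, v)   # same three-way unpack as the benchmark fix
o.backward(torch.randn_like(o), retain_graph=True)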
