Commit 6f2f41c

pytorchbot and drisspg authored
[FlexAttention] explicitly create grad_q w/ strides (pytorch#153641)

[FlexAttention] explicitly create grad_q w/ strides (pytorch#152641)

Fixes: pytorch#147463

Inductor's lowering for empty_like does not match eager behavior: the strides it produces do not follow preserve_format. See pytorch#144699.

Pull Request resolved: pytorch#152641
Approved by: https://github.com/xmfan

(cherry picked from commit a6ea63a)

Co-authored-by: drisspg <[email protected]>
1 parent 0073e33 commit 6f2f41c
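
To make the stride mismatch described in the commit message concrete, here is a minimal pure-Python sketch (no PyTorch import) that compares row-major strides with the strides of a permuted view. The helper names `contiguous_strides` and `permute` and the example sizes are hypothetical, chosen only for illustration. In eager mode, `torch.empty_like` with its default `preserve_format` keeps the permuted strides for a dense, non-overlapping input instead of falling back to the row-major ones.

```python
def contiguous_strides(size):
    """Row-major (C-contiguous) strides for a shape, counted in elements."""
    strides = [1] * len(size)
    for i in range(len(size) - 2, -1, -1):
        strides[i] = strides[i + 1] * size[i + 1]
    return tuple(strides)


def permute(size, strides, dims):
    """Size and strides of a view whose dimensions are reordered by dims."""
    return tuple(size[d] for d in dims), tuple(strides[d] for d in dims)


# Hypothetical example: a query stored as (B, S, H, D) and viewed as (B, H, S, D).
stored_size = (2, 128, 4, 64)
stored_strides = contiguous_strides(stored_size)
query_size, query_strides = permute(stored_size, stored_strides, (0, 2, 1, 3))

# Strides a freshly allocated row-major tensor of the same shape would have.
row_major = contiguous_strides(query_size)

print(query_size, query_strides)  # layout a preserve_format allocation should keep
print(row_major)                  # layout of a plain contiguous allocation
```

The two stride tuples differ even though the shapes are identical, which is the kind of disagreement a gradient buffer can run into when it is allocated with a different layout than the tensor it mirrors.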

2 files changed: +11 -4 lines changed

torch/_higher_order_ops/flex_attention.py

Lines changed: 3 additions & 2 deletions
@@ -780,11 +780,12 @@ def sdpa_dense_backward(
 ]:
     from torch._dynamo._trace_wrapped_higher_order_op import TransformGetItemToIndex

-    Bq, _, _, qk_head_dim = query.shape
+    Bq, Hq, seq_len_q, qk_head_dim = query.shape
     Bkv, Hkv, seq_len_kv, v_head_dim = value.shape

     # Get outputs before calling repeat interleave and permute to input stride orders
-    actual_grad_query = torch.empty_like(query)
+    actual_grad_query = query.new_empty((Bq, Hq, seq_len_q, qk_head_dim))
+    actual_grad_query = _permute_strides(actual_grad_query, query.stride())

     actual_grad_key = key.new_empty((Bq, Hkv, seq_len_kv, qk_head_dim))
     actual_grad_key = _permute_strides(actual_grad_key, key.stride())
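
The pattern in the hunk above (allocate with `new_empty`, then call `_permute_strides` with the reference tensor's strides) can be pictured as computing dense strides in the reference tensor's dimension order. The sketch below is pure Python and is not PyTorch's implementation of `_permute_strides`; `fill_order` and `dense_strides_like` are hypothetical names, and breaking ties between equal strides by dimension index is a simplifying assumption.

```python
def fill_order(strides):
    """Dimension indices sorted from innermost (smallest stride) to outermost."""
    return sorted(range(len(strides)), key=lambda d: strides[d])


def dense_strides_like(size, ref_strides):
    """Dense strides for size that keep the dimension ordering of ref_strides."""
    out = [0] * len(size)
    running = 1
    for d in fill_order(ref_strides):
        out[d] = running      # this dimension steps over everything inside it
        running *= size[d]
    return tuple(out)


# Hypothetical (B, H, S, D) query whose memory order is (B, S, H, D).
query_size = (2, 4, 128, 64)
query_strides = (32768, 64, 256, 1)

# Same size, dense reference: the reference strides are reproduced.
same_layout = dense_strides_like(query_size, query_strides)

# Reference with a padded batch stride (e.g. a slice of a larger buffer):
# the dimension order is kept, but the result is packed densely.
padded_ref = (65536, 64, 256, 1)
packed = dense_strides_like(query_size, padded_ref)

print(same_layout, packed)
```

Deriving strides from the dimension order, rather than copying them verbatim, also works when the new buffer's sizes differ from the reference tensor's, as with `actual_grad_key`, which is sized with `Bq` rather than `Bkv`.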

torch/_inductor/kernel/flex_attention.py

Lines changed: 8 additions & 2 deletions
@@ -38,7 +38,6 @@
     _full,
     check_and_broadcast_indices,
     empty,
-    empty_like,
     empty_strided,
     expand,
     index_output_size_and_inner_fn,
@@ -2524,7 +2523,14 @@ def flex_attention_backward(*args, **kwargs):
     grad_lse_exp2, delta = maybe_realize([grad_lse_exp2, delta])

     # # see NOTE:[TritonTemplates with multiple outputs]
-    grad_query = empty_like(query)
+    query_size = [Bq, Hq, seq_len_q, qk_head_dim]
+    grad_query_strides = infer_dense_strides(query_size, query.get_stride())
+    grad_query = empty_strided(
+        query_size,
+        stride=[sympy.sympify(s) for s in grad_query_strides],
+        dtype=query.get_dtype(),
+        device=query.get_device(),
+    )

     # Construct output layout with stride order matching value
     value_size = [Bq, Hkv, seq_len_kv, v_head_dim]
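
The replacement above passes explicitly inferred dense strides to `empty_strided`. One way to sanity-check such size/stride pairs, sketched here in pure Python under the assumption that every size is positive, is to verify that they describe a dense, non-overlapping layout. `is_dense_non_overlapping` is a hypothetical helper, not an Inductor API.

```python
def is_dense_non_overlapping(size, strides):
    """True if size/strides address each element of a packed buffer exactly once.

    Assumes every entry of size is a positive integer.
    """
    expected = 1
    # Walk dimensions from innermost to outermost, ordered by stride.
    for d in sorted(range(len(size)), key=lambda i: strides[i]):
        if size[d] == 1:
            continue  # a size-1 dimension is never stepped, so its stride is irrelevant
        if strides[d] != expected:
            return False
        expected *= size[d]
    return True


size = (2, 4, 128, 64)
dense_permuted = is_dense_non_overlapping(size, (32768, 64, 256, 1))
padded_batch = is_dense_non_overlapping(size, (65536, 64, 256, 1))
broadcast_dim = is_dense_non_overlapping((2, 4), (0, 1))

print(dense_permuted, padded_batch, broadcast_dim)
```

A permuted but packed layout passes the check, while a padded batch stride or a stride-0 (broadcast) dimension does not.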
