
Commit 070688f

eellison authored and Ryo-not-rio committed
Loosen last dim contiguity for sdpa constraint to include last dim 0,1 (pytorch#139787)
Previously we required the last dim to have stride == 1. When the last dim's size is <= 1, that is also sufficient, because the stride of a dimension with at most one element is insignificant. Fix for pytorch#138317

Pull Request resolved: pytorch#139787
Approved by: https://github.com/drisspg
1 parent 0515efc commit 070688f
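
To illustrate the reasoning above, a minimal sketch (not part of the commit; the shapes are invented for illustration): a tensor whose last dimension has size 1 can carry an arbitrary last-dim stride and still be fully dense, so requiring stride == 1 there is stricter than necessary.

    # Sketch: a size-1 last dim's stride never participates in addressing.
    import torch

    x = torch.randn(8, 1, 512)   # invented shape for illustration
    y = x.permute(0, 2, 1)       # shape (8, 512, 1); last-dim stride is 512, not 1

    print(y.shape, y.stride())   # torch.Size([8, 512, 1]) (512, 1, 512)
    print(y.is_contiguous())     # True: the size-1 dim's stride is never used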

2 files changed: +72 −3 lines

test/inductor/test_cuda_repro.py

Lines changed: 66 additions & 0 deletions

@@ -1242,6 +1242,72 @@ def outer_reduce(x):
         self.assertEqual(outer_reduce(a), out)
         self.assertTrue("for roffset" not in code)
 
+    def test_scaled_dot_product_efficient_attention_backward(self):
+        from torch import nn, Tensor
+
+        class SelfAttention(nn.Module):
+            def __init__(
+                self,
+                num_attention_heads: int = 12,
+                hidden_size: int = 768,
+                attention_probs_dropout_prob: float = 0.1,
+            ):
+                super().__init__()
+
+                self.num_attention_heads = num_attention_heads
+                self.attention_head_size = hidden_size // num_attention_heads
+
+                self.query = nn.Linear(hidden_size, hidden_size)
+                self.key = nn.Linear(hidden_size, hidden_size)
+                self.value = nn.Linear(hidden_size, hidden_size)
+
+                self.dropout_prob = attention_probs_dropout_prob
+
+            def transpose_for_scores(self, x: Tensor) -> Tensor:
+                new_x_shape = x.size()[:-1] + (
+                    self.num_attention_heads,
+                    self.attention_head_size,
+                )
+                return x.view(new_x_shape).permute(0, 2, 1, 3)
+
+            def forward(self, hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
+                query_layer = self.transpose_for_scores(self.query(hidden_states))
+                key_layer = self.transpose_for_scores(self.key(hidden_states))
+                value_layer = self.transpose_for_scores(self.value(hidden_states))
+
+                attn_output = torch.nn.functional.scaled_dot_product_attention(
+                    query_layer,
+                    key_layer,
+                    value_layer,
+                    attn_mask=attention_mask,
+                    dropout_p=self.dropout_prob if self.training else 0.0,
+                    is_causal=False,
+                )
+                return attn_output
+
+        device = torch.device("cuda")
+        num_attention_heads = 8
+        hidden_size = 512
+        attention_probs_dropout_prob = 0.0
+        model = SelfAttention(
+            num_attention_heads=num_attention_heads,
+            hidden_size=hidden_size,
+            attention_probs_dropout_prob=attention_probs_dropout_prob,
+        ).to(device)
+
+        model = torch.compile(model)
+
+        # runs without failure
+        batch_size = 8
+        length = 1
+        inputs_embeds = torch.randn(batch_size, length, hidden_size, device=device)
+        attention_mask = torch.ones(batch_size, 1, length, length, device=device)
+        attn_output = model(hidden_states=inputs_embeds, attention_mask=attention_mask)[
+            0
+        ]
+        loss = attn_output.mean()
+        loss.backward()
+
     def test_non_contiguous_unaligned_input_indices(self):
         from torch._inductor.compile_fx import remove_unaligned_input_idxs

torch/_inductor/lowering.py

Lines changed: 6 additions & 3 deletions

@@ -2349,9 +2349,12 @@ def is_aligned_realized_tensor(x):
                 (V.graph.sizevars.size_hint(x.get_stride()[i]) % ALIGNMENT) == 0
                 for i in range(len(x.get_stride()) - 1)
             )
-            return (
-                V.graph.sizevars.size_hint(x.get_stride()[-1])
-            ) == 1 and aligned_strides
+            # if the last dim size is <= 1, the stride doesn't matter
+            aligned_last_dim = (
+                V.graph.sizevars.size_hint(x.get_stride()[-1]) == 1
+                or V.graph.sizevars.size_hint(x.get_size()[-1]) <= 1
+            )
+            return aligned_last_dim and aligned_strides
 
         try:
             arg.get_stride()