
Commit aff76c0

Revert "Add fake_impl for _native_multi_head_attention (pytorch#163167)"
This reverts commit 27164b6. Reverted pytorch#163167 on behalf of https://github.com/malfet because it broke inductor-cpu-test; see https://hud.pytorch.org/hud/pytorch/pytorch/1a42656d6c43a9bb7eb90c511884ce451d29422f/1?per_page=50&name_filter=inductor-cpu-test&mergeEphemeralLF=true ([comment](pytorch#163167 (comment)))
1 parent 1a42656 commit aff76c0
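
For context, the reverted change registered a "fake" (meta) kernel for aten._native_multi_head_attention: a shape- and dtype-only implementation that FakeTensorMode, torch.export, and inductor use to trace the op without running its real CPU/CUDA kernel. The sketch below illustrates that general pattern through the public torch.library API on a made-up custom op; it is only an analogy (demo::row_sum is hypothetical), since the reverted PR registers the meta kernel directly via register_meta in torch/_meta_registrations.py, as the second diff below shows.

import torch
from torch import Tensor
from torch._subclasses.fake_tensor import FakeTensorMode

# Hypothetical custom op, purely to illustrate the fake-impl pattern.
@torch.library.custom_op("demo::row_sum", mutates_args=())
def row_sum(x: Tensor) -> Tensor:
    return x.sum(dim=-1)

# The fake impl describes output metadata only (shape/dtype/device);
# it never reads real data, so tracing can proceed without a real kernel.
@row_sum.register_fake
def _(x: Tensor) -> Tensor:
    return x.new_empty(x.shape[:-1])

with FakeTensorMode():
    fake_x = torch.empty(4, 8)       # a FakeTensor: metadata only
    print(row_sum(fake_x).shape)     # torch.Size([4]), computed by the fake impl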

File tree

2 files changed: 0 additions & 137 deletions

test/export/test_export.py

Lines changed: 0 additions & 87 deletions
@@ -1083,93 +1083,6 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         args = (torch.randn(15, 3, 256, 256), torch.ones(15, 32, 256, 256))
         self.assertEqual(gm(*args), m(*args))

-    # stride() is called for an undefined tensor
-    @testing.expectedFailureCppRuntimeNonStrict
-    def test_native_multi_attention_head(self):
-        embed_dim = 64
-        num_heads = 4
-        bs = 16
-        sl = 8
-        device = "cpu"
-
-        q = 6 * torch.rand(bs, sl, embed_dim, device=device, dtype=torch.float32) - 3
-        k = q
-        v = q
-
-        qkv = torch.nn.Linear(
-            embed_dim, 3 * embed_dim, device=device, dtype=torch.float32
-        )
-        proj = torch.nn.Linear(embed_dim, embed_dim, device=device, dtype=torch.float32)
-
-        class NativeMHA(torch.nn.Module):
-            def __init__(
-                self,
-                embed_dim,
-                num_heads,
-                qkv,
-                proj,
-                need_weights,
-                average_attn_weights,
-                mask_type,
-            ):
-                super().__init__()
-                self.qkv = qkv
-                self.proj = proj
-                self.embed_dim = embed_dim
-                self.num_heads = num_heads
-                self.need_weights = need_weights
-                self.average_attn_weights = average_attn_weights
-                self.mask_type = mask_type
-
-            def forward(self, q, k, v, key_padding_mask):
-                return torch._native_multi_head_attention(
-                    q,
-                    k,
-                    v,
-                    self.embed_dim,
-                    self.num_heads,
-                    self.qkv.weight,
-                    self.qkv.bias,
-                    self.proj.weight,
-                    self.proj.bias,
-                    key_padding_mask,
-                    need_weights=False,
-                    average_attn_weights=False,
-                    mask_type=1,  # mask_type = 1 => src_key_padding_mask, mask_type = 0 => src_mask
-                )
-
-        for mask_type in (0, 1):
-            for need_weights in (True, False):
-                for average_attn_weights in (True, False):
-                    npt = NativeMHA(
-                        embed_dim=embed_dim,
-                        num_heads=num_heads,
-                        qkv=qkv,
-                        proj=proj,
-                        need_weights=need_weights,
-                        average_attn_weights=average_attn_weights,
-                        mask_type=mask_type,
-                    )
-                    sample_input = (q, k, v, None)
-
-                    ep = export(
-                        npt,
-                        args=sample_input,
-                        dynamic_shapes={
-                            "q": {
-                                0: Dim("dim0_q", max=1024),
-                            },
-                            "k": {
-                                0: Dim("dim0_k", max=1024),
-                            },
-                            "v": {
-                                0: Dim("dim0_v", max=1024),
-                            },
-                            "key_padding_mask": None,
-                        },
-                    )
-                    self.assertEqual(ep.module()(*sample_input), npt(*sample_input))
-
     def test_unused_constant(self):
         class M(torch.nn.Module):
             def forward(self, x):
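
For reference, the removed test drives the private ATen entry point torch._native_multi_head_attention directly through the small NativeMHA wrapper above, then exports it with dynamic batch dimensions. Below is a minimal eager-mode sketch of the same call with made-up small sizes; it mirrors the wrapper's argument order and, like the test, passes need_weights=False. The shape comments follow the reasoning in the reverted meta function (second diff); since this is a private op, eager behavior may differ across builds.

import torch

B, T, embed_dim, num_heads = 2, 5, 16, 4        # arbitrary; embed_dim % num_heads == 0
q = torch.randn(B, T, embed_dim)
qkv = torch.nn.Linear(embed_dim, 3 * embed_dim)
proj = torch.nn.Linear(embed_dim, embed_dim)

# Argument order mirrors NativeMHA.forward in the removed test.
out, attn = torch._native_multi_head_attention(
    q, q, q,                     # self-attention: key = value = query
    embed_dim, num_heads,
    qkv.weight, qkv.bias,
    proj.weight, proj.bias,
    None,                        # no key_padding_mask
    need_weights=False,
    average_attn_weights=False,
    mask_type=1,                 # 1 => src_key_padding_mask, 0 => src_mask (unused here: mask is None)
)
print(out.shape)     # torch.Size([2, 5, 16]) -> (B, T, proj.weight.size(0))
print(attn.numel())  # expected 0: no weights requested (the reverted fake impl models this as an empty tensor)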

torch/_meta_registrations.py

Lines changed: 0 additions & 50 deletions
@@ -7780,56 +7780,6 @@ def _f(x):
     return _f


-# Implementation follows cuda implementation native_multi_head_attention_cuda
-@register_meta(aten._native_multi_head_attention.default)
-def native_multi_head_attention_fake(
-    query,
-    key,
-    value,
-    embed_dim,
-    num_head,
-    qkv_weight,
-    qkv_bias,
-    proj_weight,
-    proj_bias,
-    mask=None,
-    need_weights=True,
-    average_attn_weights=True,
-    mask_type=None,
-):
-    if query.is_nested or key.is_nested or value.is_nested:
-        raise NotImplementedError(
-            "_native_multi_head_attention fake implementation does not support nested tensors"
-        )
-
-    if query.numel() == 0:
-        return (query.new_empty(query.shape), query.new_empty(0))
-
-    B = query.size(0)  # B: batch size
-    T = query.size(1)  # T: target sequence length
-
-    # In native_multi_head_attention_cuda,
-    # we have proj = transform0213_gemm_nt_bias(attn_ctx, proj_weight, proj_bias, query)
-    # , which does attn_ctx @ proj_weight.T + proj_bias
-    # so the last dim of output shape is proj_weight.size(0)
-    output_dim = proj_weight.size(0)
-    output = query.new_empty(B, T, output_dim)
-
-    if need_weights:
-        if average_attn_weights:
-            # When averaging attention weights, shape is [B, T, T] (averaged over heads)
-            # T = query seq len, S = key/value seq len
-            attn_weights = query.new_empty(B, T, T)
-        else:
-            # When not averaging, shape is [B, num_head, T, T]
-            # T = query seq len, S = key/value seq len
-            attn_weights = query.new_empty(B, num_head, T, T)
-    else:
-        attn_weights = query.new_empty(0)
-
-    return (output, attn_weights)
-
-
 def _create_binary_float_meta_func(func):
     @register_meta(func)
     @out_wrapper()
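
The comment block in the reverted meta function derives the output's last dimension from the final projection: attn_ctx @ proj_weight.T + proj_bias, so the last dim equals proj_weight.size(0). A tiny stand-alone check of that shape reasoning, with arbitrary sizes and proj_weight laid out the way nn.Linear stores it, i.e. (out_features, in_features):

import torch

B, T, E = 2, 5, 16                 # arbitrary batch size, target length, embed dim
attn_ctx = torch.randn(B, T, E)    # stand-in for the attention context
proj_weight = torch.randn(E, E)    # nn.Linear weight layout: (out_features, in_features)
proj_bias = torch.randn(E)

out = attn_ctx @ proj_weight.T + proj_bias
print(out.shape)                   # torch.Size([2, 5, 16]); last dim == proj_weight.size(0)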

0 commit comments
