@@ -1083,93 +1083,6 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         args = (torch.randn(15, 3, 256, 256), torch.ones(15, 32, 256, 256))
         self.assertEqual(gm(*args), m(*args))
 
-    # stride() is called for an undefined tensor
-    @testing.expectedFailureCppRuntimeNonStrict
-    def test_native_multi_attention_head(self):
-        embed_dim = 64
-        num_heads = 4
-        bs = 16
-        sl = 8
-        device = "cpu"
-
-        q = 6 * torch.rand(bs, sl, embed_dim, device=device, dtype=torch.float32) - 3
-        k = q
-        v = q
-
-        qkv = torch.nn.Linear(
-            embed_dim, 3 * embed_dim, device=device, dtype=torch.float32
-        )
-        proj = torch.nn.Linear(embed_dim, embed_dim, device=device, dtype=torch.float32)
-
-        class NativeMHA(torch.nn.Module):
-            def __init__(
-                self,
-                embed_dim,
-                num_heads,
-                qkv,
-                proj,
-                need_weights,
-                average_attn_weights,
-                mask_type,
-            ):
-                super().__init__()
-                self.qkv = qkv
-                self.proj = proj
-                self.embed_dim = embed_dim
-                self.num_heads = num_heads
-                self.need_weights = need_weights
-                self.average_attn_weights = average_attn_weights
-                self.mask_type = mask_type
-
-            def forward(self, q, k, v, key_padding_mask):
-                return torch._native_multi_head_attention(
-                    q,
-                    k,
-                    v,
-                    self.embed_dim,
-                    self.num_heads,
-                    self.qkv.weight,
-                    self.qkv.bias,
-                    self.proj.weight,
-                    self.proj.bias,
-                    key_padding_mask,
-                    need_weights=False,
-                    average_attn_weights=False,
-                    mask_type=1,  # mask_type = 1 => src_key_padding_mask, mask_type = 0 => src_mask
-                )
-
-        for mask_type in (0, 1):
-            for need_weights in (True, False):
-                for average_attn_weights in (True, False):
-                    npt = NativeMHA(
-                        embed_dim=embed_dim,
-                        num_heads=num_heads,
-                        qkv=qkv,
-                        proj=proj,
-                        need_weights=need_weights,
-                        average_attn_weights=average_attn_weights,
-                        mask_type=mask_type,
-                    )
-                    sample_input = (q, k, v, None)
-
-                    ep = export(
-                        npt,
-                        args=sample_input,
-                        dynamic_shapes={
-                            "q": {
-                                0: Dim("dim0_q", max=1024),
-                            },
-                            "k": {
-                                0: Dim("dim0_k", max=1024),
-                            },
-                            "v": {
-                                0: Dim("dim0_v", max=1024),
-                            },
-                            "key_padding_mask": None,
-                        },
-                    )
-                    self.assertEqual(ep.module()(*sample_input), npt(*sample_input))
-
     def test_unused_constant(self):
         class M(torch.nn.Module):
             def forward(self, x):