
Commit a2f34bd

Revert "Patch the flex_attention._get_mod_type to not use inspect.signature when computing num_positional_args (an alternative fix for flex attention graph break on create_block_mask) (pytorch#164923)"
This reverts commit 3401665. Reverted pytorch#164923 on behalf of https://github.com/pytorch-auto-revert due to: reverted automatically by PyTorch's autorevert bot; to avoid this behaviour, add the tag `autorevert: disable` ([comment](pytorch#164923 (comment)))
1 parent a63ab0b commit a2f34bd

3 files changed: +6 -90 lines

test/dynamo/test_repros.py

Lines changed: 0 additions & 63 deletions

```diff
@@ -46,7 +46,6 @@
 from torch._dynamo.debug_utils import same_two_models
 from torch._dynamo.testing import (
     CompileCounter,
-    CompileCounterWithBackend,
     EagerAndRecordGraphs,
     rand_strided,
     same,
@@ -55,7 +54,6 @@
 )
 from torch._inductor.utils import fresh_cache
 from torch.nn import functional as F
-from torch.nn.attention.flex_attention import create_block_mask, flex_attention
 from torch.profiler import profile, ProfilerActivity
 from torch.testing._internal.common_cuda import (
     PLATFORM_SUPPORTS_FLASH_ATTENTION,
@@ -7371,67 +7369,6 @@ def fn():
         )
         self.assertEqual(explain_output.break_reasons[0].reason, expected_msg)
 
-    @parametrize("backend", ["eager", "inductor"])
-    def test_issue164247(self, backend: str):
-        if backend == "inductor" and torch._dynamo.config.dynamic_shapes:
-            raise unittest.SkipTest(
-                "Skip only in dynamic-shapes wrapper (known issue #157612)"
-            )
-
-        class MixedFakeModeModel(nn.Module):
-            def __init__(self, dim=64):
-                super().__init__()
-                self.dim = dim
-                self.lin = torch.nn.Linear(64, 64)
-
-            def forward(self, x):
-                batch_size, seq_len, _ = x.shape
-
-                # Process input first - this creates fake tensors in export's fake mode
-                processed = self.lin(x)
-
-                # Create some computation that depends on processed tensor
-                intermediate = processed.sum(dim=-1).detach()  # Shape: (batch, seq_len)
-
-                def dynamic_mask_function(batch_idx, head_idx, q_idx, kv_idx):
-                    threshold = intermediate[
-                        batch_idx, q_idx % seq_len
-                    ]  # Access the captured tensor
-                    return (kv_idx <= q_idx) & (threshold > 0)
-
-                block_mask = create_block_mask(
-                    mask_mod=dynamic_mask_function,
-                    B=batch_size,
-                    H=None,
-                    Q_LEN=seq_len,
-                    KV_LEN=seq_len,
-                    device=x.device,
-                    _compile=False,
-                )
-                q = processed.view(batch_size, 1, seq_len, self.dim)
-                k = processed.view(batch_size, 1, seq_len, self.dim)
-                v = processed.view(batch_size, 1, seq_len, self.dim)
-
-                out = torch.compile(flex_attention)(q, k, v, block_mask=block_mask)
-                out = flex_attention(q, k, v, block_mask=block_mask)
-
-                return out
-
-        backend_counter = CompileCounterWithBackend(backend)
-        model = MixedFakeModeModel()
-        compiled = torch.compile(model, backend=backend_counter, fullgraph=True)
-
-        if backend == "inductor":
-            # A known InductorError Issue https://github.com/pytorch/pytorch/issues/157612
-            with self.assertRaises(RuntimeError):
-                compiled(torch.randn(2, 128, 64))
-        else:
-            compiled(torch.randn(2, 128, 64))
-
-        # One graph, so no graph breaks
-        self.assertEqual(backend_counter.frame_count, 1)
-        self.assertEqual(len(backend_counter.graphs), 1)
-
 
 class ReproTestsDevice(torch._dynamo.test_case.TestCase):
     def test_sub_alpha_scalar_repro(self, device):
```

torch/_dynamo/variables/functions.py

Lines changed: 1 addition & 13 deletions

```diff
@@ -1320,21 +1320,9 @@ def has_closure(self):
 
     def const_getattr(self, tx, name):
         if name == "__name__":
-            return self.get_name()
-        if name == "__code__":
-            return self.get_code()
-        if name == "__defaults__":
-            d = getattr(self, "defaults", None)
-            return d.as_python_constant() if d else None
+            return self.fn_name.as_python_constant()
         return super().const_getattr(tx, name)
 
-    def call_obj_hasattr(self, tx: "InstructionTranslator", name):
-        if name == "__code__":
-            return variables.ConstantVariable.create(hasattr(self, "code"))
-        if name == "__defaults__":
-            return variables.ConstantVariable.create(hasattr(self, "defaults"))
-        return super().call_obj_hasattr(tx, name)
-
     def has_self(self):
         return False
 
```
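The handlers removed above let Dynamo answer `__code__` and `__defaults__` lookups on traced functions, which is what the reverted `_get_mod_type` patch relied on. As a rough illustration of the attribute arithmetic involved (the `score_mod` function below is a made-up example, not code from this patch), in plain Python:

```python
# Illustrative sketch only: how __code__/__defaults__ yield the number of
# required positional arguments of an ordinary Python function.
def score_mod(score, batch_idx, head_idx, q_idx, kv_idx, scale=1.0):
    return score * scale

num_positional_total = score_mod.__code__.co_argcount      # 6: all positional parameters
num_defaults = len(score_mod.__defaults__ or ())            # 1: only `scale` has a default
num_positional_args = num_positional_total - num_defaults   # 5 -> would classify as a score_mod
print(num_positional_args)
```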

torch/nn/attention/flex_attention.py

Lines changed: 5 additions & 14 deletions

```diff
@@ -266,20 +266,11 @@ def _get_mod_type(fn: Callable) -> _ModificationType:
     considered as a score_mod function. If the function has 4 positional arguments, it is
     considered as a mask function.
     """
-    if hasattr(fn, "__code__"):
-        code = fn.__code__
-        num_positional_total = code.co_argcount
-        defaults = ()
-        if hasattr(fn, "__defaults__"):
-            defaults = fn.__defaults__ or ()
-        num_defaults = len(defaults)
-        num_positional_args = num_positional_total - num_defaults
-    else:
-        num_positional_args = sum(
-            1
-            for param in inspect.signature(fn).parameters.values()
-            if param.default is inspect.Parameter.empty
-        )
+    num_positional_args = sum(
+        1
+        for param in inspect.signature(fn).parameters.values()
+        if param.default is inspect.Parameter.empty
+    )
     assert num_positional_args == 5 or num_positional_args == 4
     if num_positional_args == 5:
         return _ModificationType.SCORE_MOD
```
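After the revert, `_get_mod_type` is back to counting parameters without defaults via `inspect.signature`, so a 4-argument callable is treated as a mask function and a 5-argument one as a score_mod. A minimal sketch of that restored counting (the `causal_mask` callable below is a made-up example):

```python
# Sketch of the restored inspect.signature-based counting; causal_mask is an
# illustrative mask_mod, not code from this patch.
import inspect

def causal_mask(batch_idx, head_idx, q_idx, kv_idx):
    return kv_idx <= q_idx

num_positional_args = sum(
    1
    for param in inspect.signature(causal_mask).parameters.values()
    if param.default is inspect.Parameter.empty
)
print(num_positional_args)  # 4 -> classified as a mask function
```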
