
Commit 3401665

jmaczan authored and pytorchmergebot committed
Patch the flex_attention._get_mod_type to not use inspect.signature when computing num_positional_args (an alternative fix for flex attention graph break on create_block_mask) (pytorch#164923)
The initial fix for the `inspect.signature` issue did not take the right approach (pytorch#164349 (review)). As @williamwen42 suggested (pytorch#164349 (comment)), for now we can simply get rid of the `inspect.signature` call in flex_attention to resolve this high-priority issue (pytorch#164247 (comment)). In this PR I did exactly that: I limited the scope of the fix to computing `num_positional_args` in `flex_attention._get_mod_type` from properties returned by `NestedUserFunctionVariable.const_getattr` (some were missing, so I added them).

Fixes pytorch#164247

Pull Request resolved: pytorch#164923
Approved by: https://github.com/williamwen42
1 parent: 8c60f4a
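As a minimal illustration of the argument-counting approach described in the commit message (not the patch itself; `count_required_positional_args`, `causal_mask`, and `rel_bias` are hypothetical names used only for this sketch):

import functools
import inspect


def count_required_positional_args(fn):
    # Prefer the function's own __code__/__defaults__ attributes, which Dynamo can
    # answer as constants, and fall back to inspect.signature otherwise.
    if hasattr(fn, "__code__"):
        defaults = getattr(fn, "__defaults__", None) or ()
        return fn.__code__.co_argcount - len(defaults)
    # Fallback for callables without __code__, e.g. functools.partial objects.
    return sum(
        1
        for param in inspect.signature(fn).parameters.values()
        if param.default is inspect.Parameter.empty
    )


def causal_mask(b, h, q_idx, kv_idx):  # 4 required args -> treated as a mask function
    return kv_idx <= q_idx


def rel_bias(score, b, h, q_idx, kv_idx, scale=1.0):  # 5 required args -> score_mod
    return score + scale * (q_idx - kv_idx)


assert count_required_positional_args(causal_mask) == 4
assert count_required_positional_args(rel_bias) == 5
assert count_required_positional_args(functools.partial(rel_bias, scale=2.0)) == 5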

File tree: 3 files changed (+90, -6 lines)


test/dynamo/test_repros.py

Lines changed: 63 additions & 0 deletions

@@ -46,6 +46,7 @@
 from torch._dynamo.debug_utils import same_two_models
 from torch._dynamo.testing import (
     CompileCounter,
+    CompileCounterWithBackend,
     EagerAndRecordGraphs,
     rand_strided,
     same,
@@ -54,6 +55,7 @@
 )
 from torch._inductor.utils import fresh_cache
 from torch.nn import functional as F
+from torch.nn.attention.flex_attention import create_block_mask, flex_attention
 from torch.profiler import profile, ProfilerActivity
 from torch.testing._internal.common_cuda import (
     PLATFORM_SUPPORTS_FLASH_ATTENTION,
@@ -7369,6 +7371,67 @@ def fn():
         )
         self.assertEqual(explain_output.break_reasons[0].reason, expected_msg)

+    @parametrize("backend", ["eager", "inductor"])
+    def test_issue164247(self, backend: str):
+        if backend == "inductor" and torch._dynamo.config.dynamic_shapes:
+            raise unittest.SkipTest(
+                "Skip only in dynamic-shapes wrapper (known issue #157612)"
+            )
+
+        class MixedFakeModeModel(nn.Module):
+            def __init__(self, dim=64):
+                super().__init__()
+                self.dim = dim
+                self.lin = torch.nn.Linear(64, 64)
+
+            def forward(self, x):
+                batch_size, seq_len, _ = x.shape
+
+                # Process input first - this creates fake tensors in export's fake mode
+                processed = self.lin(x)
+
+                # Create some computation that depends on processed tensor
+                intermediate = processed.sum(dim=-1).detach()  # Shape: (batch, seq_len)
+
+                def dynamic_mask_function(batch_idx, head_idx, q_idx, kv_idx):
+                    threshold = intermediate[
+                        batch_idx, q_idx % seq_len
+                    ]  # Access the captured tensor
+                    return (kv_idx <= q_idx) & (threshold > 0)
+
+                block_mask = create_block_mask(
+                    mask_mod=dynamic_mask_function,
+                    B=batch_size,
+                    H=None,
+                    Q_LEN=seq_len,
+                    KV_LEN=seq_len,
+                    device=x.device,
+                    _compile=False,
+                )
+                q = processed.view(batch_size, 1, seq_len, self.dim)
+                k = processed.view(batch_size, 1, seq_len, self.dim)
+                v = processed.view(batch_size, 1, seq_len, self.dim)
+
+                out = torch.compile(flex_attention)(q, k, v, block_mask=block_mask)
+                out = flex_attention(q, k, v, block_mask=block_mask)
+
+                return out
+
+        backend_counter = CompileCounterWithBackend(backend)
+        model = MixedFakeModeModel()
+        compiled = torch.compile(model, backend=backend_counter, fullgraph=True)
+
+        if backend == "inductor":
+            # A known InductorError Issue https://github.com/pytorch/pytorch/issues/157612
+            with self.assertRaises(RuntimeError):
+                compiled(torch.randn(2, 128, 64))
+        else:
+            compiled(torch.randn(2, 128, 64))
+
+        # One graph, so no graph breaks
+        self.assertEqual(backend_counter.frame_count, 1)
+        self.assertEqual(len(backend_counter.graphs), 1)
+

 class ReproTestsDevice(torch._dynamo.test_case.TestCase):
     def test_sub_alpha_scalar_repro(self, device):

torch/_dynamo/variables/functions.py

Lines changed: 13 additions & 1 deletion

@@ -1320,9 +1320,21 @@ def has_closure(self):

     def const_getattr(self, tx, name):
         if name == "__name__":
-            return self.fn_name.as_python_constant()
+            return self.get_name()
+        if name == "__code__":
+            return self.get_code()
+        if name == "__defaults__":
+            d = getattr(self, "defaults", None)
+            return d.as_python_constant() if d else None
         return super().const_getattr(tx, name)

+    def call_obj_hasattr(self, tx: "InstructionTranslator", name):
+        if name == "__code__":
+            return variables.ConstantVariable.create(hasattr(self, "code"))
+        if name == "__defaults__":
+            return variables.ConstantVariable.create(hasattr(self, "defaults"))
+        return super().call_obj_hasattr(tx, name)
+
     def has_self(self):
         return False

torch/nn/attention/flex_attention.py

Lines changed: 14 additions & 5 deletions

@@ -266,11 +266,20 @@ def _get_mod_type(fn: Callable) -> _ModificationType:
     considered as a score_mod function. If the function has 4 positional arguments, it is
     considered as a mask function.
     """
-    num_positional_args = sum(
-        1
-        for param in inspect.signature(fn).parameters.values()
-        if param.default is inspect.Parameter.empty
-    )
+    if hasattr(fn, "__code__"):
+        code = fn.__code__
+        num_positional_total = code.co_argcount
+        defaults = ()
+        if hasattr(fn, "__defaults__"):
+            defaults = fn.__defaults__ or ()
+        num_defaults = len(defaults)
+        num_positional_args = num_positional_total - num_defaults
+    else:
+        num_positional_args = sum(
+            1
+            for param in inspect.signature(fn).parameters.values()
+            if param.default is inspect.Parameter.empty
+        )
     assert num_positional_args == 5 or num_positional_args == 4
     if num_positional_args == 5:
         return _ModificationType.SCORE_MOD
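
For reference, a hedged sketch of the 4-vs-5 positional-argument convention this helper enforces. `_get_mod_type` and `_ModificationType` are private to `torch.nn.attention.flex_attention`, so importing them directly is an assumption for illustration only and may break across releases:

from torch.nn.attention.flex_attention import _ModificationType, _get_mod_type


def causal_mask(b, h, q_idx, kv_idx):  # 4 positional args -> mask function
    return kv_idx <= q_idx


def rel_bias(score, b, h, q_idx, kv_idx):  # 5 positional args -> score_mod
    return score + (q_idx - kv_idx)


assert _get_mod_type(rel_bias) is _ModificationType.SCORE_MOD
assert _get_mod_type(causal_mask) is _ModificationType.MASK_MOD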
