Skip to content

Commit ad15cc2

Browse files
authored
thunderfx: Update unsupported ctx with new functorch increment/decrement (#2732)
1 parent 2732ecf commit ad15cc2

File tree

2 files changed

+35
-17
lines changed

2 files changed

+35
-17
lines changed

thunder/dynamo/utils.py

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -399,17 +399,32 @@ def get_nodes_in_unsupported_ctx_regions(gm: torch.fx.GraphModule) -> set[torch.
399399
# NOTE - Currently doesn't ban any ctx (previously used for `no_grad` and `autocast`).
400400

401401
nodes_in_unsupported_ctx_regions: set[torch.fx.Node] = set()
402-
ctx_cnt = 0 # Count of we have seen till now
402+
ctx_stack = [] # Unsupported context.__enter__ we have seen till now.
403+
404+
CTX_ENTER_EXIT_MAP = {
405+
torch._C._functorch._vmap_increment_nesting: torch._C._functorch._vmap_decrement_nesting,
406+
}
407+
408+
# Older versions of PyTorch may not have `torch._functorch.predispatch`
409+
if hasattr(torch._functorch, "predispatch"):
410+
CTX_ENTER_EXIT_MAP[torch._functorch.predispatch._vmap_increment_nesting] = (
411+
torch._functorch.predispatch._vmap_decrement_nesting
412+
)
413+
414+
UNSUPPORTED_THUNDER_CTX_ENTER = tuple(CTX_ENTER_EXIT_MAP.keys())
415+
UNSUPPORTED_THUNDER_CTX_EXIT = tuple(CTX_ENTER_EXIT_MAP.values())
403416

404-
UNSUPPORTED_THUNDER_CTX = (torch._C._functorch._vmap_increment_nesting, torch._C._functorch._vmap_decrement_nesting)
405417
for node in gm.graph.nodes:
406-
if node.op == "call_function" and node.target in UNSUPPORTED_THUNDER_CTX:
407-
ctx_cnt += 1
408-
elif node.op == "call_function" and node.target in UNSUPPORTED_THUNDER_CTX:
409-
ctx_cnt -= 1
410-
else:
411-
if ctx_cnt > 0:
412-
nodes_in_unsupported_ctx_regions.add(node)
418+
# All the cases when `node` should be marked as in unsupported ctx region.
419+
if node.op == "call_function" and node.target in UNSUPPORTED_THUNDER_CTX_ENTER:
420+
ctx_stack.append(node.target)
421+
nodes_in_unsupported_ctx_regions.add(node)
422+
elif node.op == "call_function" and node.target in UNSUPPORTED_THUNDER_CTX_EXIT:
423+
enter_fn = ctx_stack.pop()
424+
assert CTX_ENTER_EXIT_MAP[enter_fn] == node.target, "Mismatched ctx enter-exit"
425+
nodes_in_unsupported_ctx_regions.add(node)
426+
elif len(ctx_stack) > 0:
427+
nodes_in_unsupported_ctx_regions.add(node)
413428

414429
return nodes_in_unsupported_ctx_regions
415430

thunder/tests/test_networks.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -357,9 +357,6 @@ def test_quantization():
357357
assert_close(v, sd2[k])
358358

359359

360-
@pytest.mark.skip(
361-
reason="incompatible with transformers >= 4.55.4, see https://github.com/Lightning-AI/lightning-thunder/issues/2726"
362-
)
363360
@thunder.tests.framework.requiresCUDA
364361
def test_thunderfx_mistral_nemo_small():
365362
"""
@@ -402,7 +399,11 @@ def test_thunderfx_mistral_nemo_small():
402399
input_ids = torch.randint(0, config.vocab_size, iid_size, device=device)
403400
attention_mask = torch.ones_like(input_ids)
404401

405-
output = mdl(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
402+
with warnings.catch_warnings():
403+
warnings.filterwarnings(
404+
"ignore", category=FutureWarning, message=r".*`isinstance\(treespec, LeafSpec\)` is deprecated.*"
405+
)
406+
output = mdl(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
406407
logits = output.logits
407408
grad_logits = torch.randn_like(logits)
408409
logits.backward(grad_logits)
@@ -424,9 +425,6 @@ def qwen2():
424425
return [(phi3), (qwen2)]
425426

426427

427-
@pytest.mark.skip(
428-
reason="incompatible with transformers >= 4.55.4, see https://github.com/Lightning-AI/lightning-thunder/issues/2726"
429-
)
430428
@thunder.tests.framework.requiresCUDA
431429
@pytest.mark.parametrize("model_fn", _get_model_config_pairs())
432430
def test_hf_for_nemo(model_fn):
@@ -460,7 +458,11 @@ def test_hf_for_nemo(model_fn):
460458
ref_output = model(input_ids=input_ids, labels=input_ids)
461459
ref_loss = ref_output.loss
462460

463-
compiled_output = compiled_model(input_ids=input_ids, labels=input_ids)
461+
with warnings.catch_warnings():
462+
warnings.filterwarnings(
463+
"ignore", category=FutureWarning, message=r".*`isinstance\(treespec, LeafSpec\)` is deprecated.*"
464+
)
465+
compiled_output = compiled_model(input_ids=input_ids, labels=input_ids)
464466
compiled_loss = compiled_output.loss
465467

466468
# Less strict tolerance probably due to different type promotion order for bfloat16
@@ -523,6 +525,7 @@ def test_hf_for_nemo(model_fn):
523525
@requiresCUDA
524526
@pytest.mark.skip(
525527
reason="incompatible with transformers >= 4.55.4, see https://github.com/Lightning-AI/lightning-thunder/issues/2726"
528+
"Error Message: 'DynamicCache' object has no attribute 'get_usable_length'. Did you mean: 'get_seq_length'?"
526529
)
527530
@requiresDeviceMemory(required_memory_bytes=int(0.7 * 1024 * 1024 * 1024))
528531
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)