2929from lightning .fabric import Fabric
3030from lightning .fabric .plugins import FSDPPrecision
3131from lightning .fabric .strategies import FSDPStrategy
32+ from lightning .fabric .utilities .imports import _TORCH_LESS_EQUAL_2_6
3233from lightning .fabric .utilities .load import _load_distributed_checkpoint
3334from lightning .fabric .wrappers import _FabricOptimizer
3435from tests_fabric .helpers .datasets import RandomDataset
@@ -411,8 +412,10 @@ def test_reapply_compile():
411412 fabric .launch ()
412413
413414 model = BoringModel ()
414- # compile_kwargs = {"mode": "reduce-overhead"}
415- compiled_model = torch .compile (model ) # , **compile_kwargs
415+ # currently (PyTorch 2.6) using ruduce-overhead here casues a RuntimeError:
416+ # Error: accessing tensor output of CUDAGraphs that has been overwritten by a subsequent run.
417+ compile_kwargs = {"mode" : "reduce-overhead" } if _TORCH_LESS_EQUAL_2_6 else {}
418+ compiled_model = torch .compile (model , ** compile_kwargs )
416419 torch .compile .reset_mock ()
417420
418421 fabric_model = fabric .setup (compiled_model , _reapply_compile = True )
@@ -421,7 +424,7 @@ def test_reapply_compile():
421424 assert isinstance (fabric_model ._forward_module ._orig_mod , FullyShardedDataParallel )
422425
423426 # Assert we called compile again with the same arguments, but on the FSDP-wrapped module
424- torch .compile .assert_called_with (fabric_model ._forward_module ._orig_mod ) # , **compile_kwargs
427+ torch .compile .assert_called_with (fabric_model ._forward_module ._orig_mod , ** compile_kwargs )
425428
426429 assert fabric_model ._original_module == model
427430 assert fabric_model ._forward_module ._orig_mod .module == model
0 commit comments