Skip to content

Commit 2390bc5

Browse files
authored
Fix DeepCompile for PyTorch 2.8/2.9 compatibility (#7755)
PyTorch 2.8 added a new static_lifetime_input_indices parameter to the partition function. This breaks DeepCompile with ZeRO stage 3. This PR updates `partition_recompute_ds_params` to accept `**kwargs` and forward them to the underlying partition function, maintaining backward compatibility with PyTorch 2.6/2.7. DeepCompile works with PyTorch v2.9 when using ZeRO Stage 1 or 2. However, ZeRO Stage 3 is not currently supported on PyTorch v2.9 (it still works on PyTorch <= v2.8). DeepCompile tests are skipped when PyTorch version is v2.9 and ZeRO stage is 3. --------- Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
1 parent 1d140f8 commit 2390bc5

File tree

5 files changed

+17
-7
lines changed

5 files changed

+17
-7
lines changed

deepspeed/compile/inductor.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,10 @@
 
 try:
     import torch.utils._pytree as pytree
-    from torch._functorch.aot_autograd import create_aot_dispatcher_function
     from torch._inductor.lowering import register_lowering, fallbacks, add_needs_realized_inputs
     from torch._inductor.ir import TensorBox, FallbackKernel, Layout, IRNode
     from torch._inductor.virtualized import V
     from torch._inductor.scheduler import Scheduler
-
-    original_create_aot_dispatcher_function = create_aot_dispatcher_function
 except ImportError:
     pass
1916

deepspeed/compile/partitioner.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,10 +87,10 @@ def get_wrapped_partitioner(
     partition_fn,
 ):
 
-    def partition_recompute_ds_params(joint_module: GraphModule, _joint_inputs, *,
-                                      num_fwd_outputs) -> Tuple[GraphModule, GraphModule]:
+    def partition_recompute_ds_params(joint_module: GraphModule, _joint_inputs, *, num_fwd_outputs,
+                                      **kwargs) -> Tuple[GraphModule, GraphModule]:
         if z3_partition:
             _recompute_param_aliases(joint_module.graph, param_indices)
-        return partition_fn(joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs)
+        return partition_fn(joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs, **kwargs)
 
     return partition_recompute_ds_params

deepspeed/compile/util.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
 
 
 def is_deepcompile_supported() -> bool:
-    return required_torch_version(min_version=2.6, max_version=2.7) and get_accelerator().device_name() == "cuda"
+    return required_torch_version(min_version=2.6, max_version=2.9) and get_accelerator().device_name() == "cuda"
 
 
 dc_handle = None

deepspeed/runtime/engine.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@
 from deepspeed.runtime import lr_schedules
 from deepspeed.utils import groups
 from deepspeed.utils import logger, log_dist, log_dist_once, instrument_w_nvtx
+from deepspeed.utils.torch import required_torch_version
 from deepspeed.utils.z3_leaf_module import apply_zero_leaf_module_config
 from deepspeed.utils.timer import NoopTimer, ThroughputTimer, SynchronizedWallClockTimer, \
     FORWARD_MICRO_TIMER, BACKWARD_MICRO_TIMER, BACKWARD_INNER_MICRO_TIMER, BACKWARD_REDUCE_MICRO_TIMER, \
@@ -4296,6 +4297,10 @@ def passes_name_to_fn(passes):
         elif self.zero_optimization_stage() == ZeroStageEnum.gradients:
             backend = init_z1(self, backend, compile_config, compile_kwargs, schedule, use_z2=True)
         elif self.zero_optimization_stage() == ZeroStageEnum.weights:
+            if required_torch_version(min_version=2.9):
+                raise RuntimeError(
+                    "DeepCompile with ZeRO stage 3 is not currently supported on PyTorch >= 2.9. "
+                    "Please use ZeRO stage 1 or 2 with DeepCompile, or disable DeepCompile for ZeRO stage 3.")
             backend = init_z3(self, backend, compile_config, compile_kwargs, schedule)
 
         # Hook state must align with whether DeepCompile is active.

tests/unit/v1/compile/test_compile_zero.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,8 @@ class TestDeepCompile(DistributedTest):
     def test(self, zero_stage, dtype, deepcompile):
         if not required_torch_version(min_version=2.6):
             pytest.skip("DeepCompile requires PyTorch >= v2.6")
+        if zero_stage == 3 and required_torch_version(min_version=2.9):
+            pytest.skip("DeepCompile with ZeRO stage 3 is not currently supported on PyTorch >= 2.9")
 
         if dtype == torch.bfloat16:
             skip_on_arch(min_arch=8)
@@ -123,6 +125,8 @@ def test_padded_shard_handling(self, zero_stage, dtype):
         """Test that parameters with padding (uneven division) work correctly with DeepCompile"""
         if not required_torch_version(min_version=2.6):
             pytest.skip("DeepCompile requires PyTorch >= v2.6")
+        if required_torch_version(min_version=2.9):
+            pytest.skip("DeepCompile with ZeRO stage 3 is not supported on PyTorch >= 2.9")
 
         if get_accelerator().device_name() == "cpu":
             pytest.skip("CPU does not support this test yet")
@@ -156,6 +160,8 @@ def test_free_activation_mode(self, zero_stage, dtype):
         """Test that eagerly free activations work correctly and the threshold is configurable"""
         if not required_torch_version(min_version=2.6):
             pytest.skip("DeepCompile requires PyTorch >= v2.6")
+        if zero_stage == 3 and required_torch_version(min_version=2.9):
+            pytest.skip("DeepCompile with ZeRO stage 3 is not supported on PyTorch >= 2.9")
 
         if get_accelerator().device_name() == "cpu":
             pytest.skip("CPU does not support this test yet")
@@ -187,6 +193,8 @@ def test_fusing_allgather_and_autocast(self, zero_stage, dtype):
         """Test that allgather and autocast can be correctly fused with DeepCompile"""
         if not required_torch_version(min_version=2.6):
             pytest.skip("DeepCompile requires PyTorch >= v2.6")
+        if required_torch_version(min_version=2.9):
+            pytest.skip("DeepCompile with ZeRO stage 3 is not supported on PyTorch >= 2.9")
 
         if get_accelerator().device_name() == "cpu":
             pytest.skip("CPU does not support this test yet")

0 commit comments

Comments (0)