
Commit 2a1fe91

[docker] trim megatron and sglang patch (#552)
1 parent a5cadbd commit 2a1fe91

File tree

10 files changed: +439, -687 lines


docker/patch/latest/megatron.patch

Lines changed: 14 additions & 314 deletions
Large diffs are not rendered by default.

docker/patch/latest/sglang.patch

Lines changed: 21 additions & 308 deletions
Large diffs are not rendered by default.

requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,6 @@
 accelerate
 datasets
+deepspeed
 httpx[http2]
 mcp[cli]
 pillow
@@ -10,4 +11,3 @@ tensorboard
 torch
 transformers
 wandb
-deepspeed

slime/backends/fsdp_utils/actor.py

Lines changed: 15 additions & 14 deletions
@@ -86,16 +86,16 @@ def init(self, args: Namespace, role: str, wandb_run_id: str, with_ref: bool = F
 
         if args.optimizer == "deepspeed_cpu_adam":
             optimizer_config = {
-                'lr': args.lr,
-                'betas': (args.adam_beta1, args.adam_beta2),
-                'eps': args.adam_eps,
-                'weight_decay': args.weight_decay,
-                'adamw_mode': True, # Use AdamW mode (decoupled weight decay)
-                'fp32_optimizer_states': True, # Keep optimizer states in FP32
+                "lr": args.lr,
+                "betas": (args.adam_beta1, args.adam_beta2),
+                "eps": args.adam_eps,
+                "weight_decay": args.weight_decay,
+                "adamw_mode": True,  # Use AdamW mode (decoupled weight decay)
+                "fp32_optimizer_states": True,  # Keep optimizer states in FP32
             }
-
+
             self.optimizer = FSDPCPUAdamWrapper(optimizer_config, self.model)
-
+
         elif args.optimizer == "adam":
             self.optimizer = torch.optim.AdamW(
                 self.model.parameters(),
@@ -104,9 +104,11 @@ def init(self, args: Namespace, role: str, wandb_run_id: str, with_ref: bool = F
                 eps=args.adam_eps,
                 weight_decay=args.weight_decay,
             )
-
+
         else:
-            raise ValueError(f"Unsupported optimizer: {args.optimizer}. Supported options: 'adam', 'deepspeed_cpu_adam'")
+            raise ValueError(
+                f"Unsupported optimizer: {args.optimizer}. Supported options: 'adam', 'deepspeed_cpu_adam'"
+            )
 
         # TODO: load
 
@@ -149,7 +151,7 @@ def sleep(self, tags: str | Iterable[str] | None) -> None:
 
         if isinstance(tags, str):
             tags = (tags,)
-
+
         if torch_memory_saver is not None:
             torch_memory_saver.pause()
 
@@ -164,10 +166,10 @@ def wake_up(self, tags: str | Iterable[str] | None) -> None:
         """
         if not getattr(self.args, "offload", False):
             return
-
+
         if isinstance(tags, str):
             tags = (tags,)
-
+
         if torch_memory_saver is not None:
             torch_memory_saver.resume()
 
@@ -555,7 +557,6 @@ def update_weights(self) -> None:  # type: ignore[override]
         self.weight_updater.connect_rollout_engines(rollout_engines, rollout_engine_lock)
         dist.barrier(group=get_gloo_group())
 
-
         with torch_memory_saver.disable() if self.args.offload and not torch.version.hip else nullcontext():
             self.weight_updater.update_weights()
 
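For orientation, a minimal sketch (not repo code) of what this optimizer selection does, assuming an argparse-style namespace with the attribute names the diff uses (`optimizer`, `lr`, `adam_beta1`, `adam_beta2`, `adam_eps`, `weight_decay`). The hyperparameter values below are made up, and the `deepspeed_cpu_adam` branch is stubbed out since it needs DeepSpeed plus the wrapper defined in the next file:

```python
from argparse import Namespace

import torch
import torch.nn as nn

# Hypothetical values; only the attribute names come from the diff above.
args = Namespace(
    optimizer="adam",  # "deepspeed_cpu_adam" would take the FSDPCPUAdamWrapper path instead
    lr=1e-5,
    adam_beta1=0.9,
    adam_beta2=0.95,
    adam_eps=1e-8,
    weight_decay=0.1,
)
model = nn.Linear(8, 8)

if args.optimizer == "adam":
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=args.lr,
        betas=(args.adam_beta1, args.adam_beta2),
        eps=args.adam_eps,
        weight_decay=args.weight_decay,
    )
elif args.optimizer == "deepspeed_cpu_adam":
    # The real branch builds the optimizer_config dict shown above and hands it,
    # together with the FSDP-wrapped model, to FSDPCPUAdamWrapper (next file).
    raise NotImplementedError("requires DeepSpeed and the wrapper below")
else:
    raise ValueError(
        f"Unsupported optimizer: {args.optimizer}. Supported options: 'adam', 'deepspeed_cpu_adam'"
    )
```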

Lines changed: 36 additions & 35 deletions
@@ -1,64 +1,63 @@
-from typing import Dict, List, Any
+from typing import Any, Dict, List
 
 import torch
-import torch.distributed as dist
 import torch.nn as nn
-from torch.distributed.tensor import DTensor
 from deepspeed.ops.adam import DeepSpeedCPUAdam
+from torch.distributed.tensor import DTensor
 
 
 class FSDPCPUAdamWrapper:
     """
     Wrapper for DeepSpeedCPUAdam to work with FSDP models where parameters are on GPU.
-
+
     DeepSpeedCPUAdam requires both parameters and gradients to be on CPU. This wrapper:
     1. Maintains CPU shadow copies of GPU parameters (contiguous, proper dtype)
     2. Copies gradients from GPU to CPU before optimizer step (contiguous)
     3. Runs optimizer update on CPU
     4. Copies updated parameters back to GPU
-
+
     Following the parameter copy pattern from update_weight_utils.py
     """
-
+
     def __init__(self, optimizer_config: Dict[str, Any], model: nn.Module) -> None:
         self.model: nn.Module = model
         self.gpu_params: List[nn.Parameter] = list(model.parameters())
         self.optimizer_config: Dict[str, Any] = optimizer_config
         self.cpu_params: List[torch.Tensor] = []
         self.cpu_optimizer: DeepSpeedCPUAdam
-
+
         # Create CPU shadow copies of parameters using the pattern from update_weight_utils.py
         # Store only the LOCAL SHARD for each rank, not the full tensor
         for gpu_param in self.gpu_params:
             param_data = gpu_param.detach()
             if isinstance(param_data, DTensor):
                 param_data = param_data.to_local()
-
-            cpu_param = param_data.contiguous().to(device='cpu', dtype=torch.float32, non_blocking=True)
+
+            cpu_param = param_data.contiguous().to(device="cpu", dtype=torch.float32, non_blocking=True)
             cpu_param.requires_grad_(True)
-
+
             assert cpu_param.is_contiguous(), f"CPU param must be contiguous for AVX"
             assert cpu_param.dtype == torch.float32, f"CPU param must be FP32 for DeepSpeed"
-
+
             self.cpu_params.append(cpu_param)
-
+
         torch.cuda.synchronize()
 
         self.cpu_optimizer = DeepSpeedCPUAdam(
             self.cpu_params,
-            lr=self.optimizer_config['lr'],
-            betas=self.optimizer_config['betas'],
-            eps=self.optimizer_config['eps'],
-            weight_decay=self.optimizer_config['weight_decay'],
-            adamw_mode=self.optimizer_config['adamw_mode'],
-            fp32_optimizer_states=self.optimizer_config['fp32_optimizer_states'],
+            lr=self.optimizer_config["lr"],
+            betas=self.optimizer_config["betas"],
+            eps=self.optimizer_config["eps"],
+            weight_decay=self.optimizer_config["weight_decay"],
+            adamw_mode=self.optimizer_config["adamw_mode"],
+            fp32_optimizer_states=self.optimizer_config["fp32_optimizer_states"],
         )
-
+
         self.param_groups = self.cpu_optimizer.param_groups
-
+
     def zero_grad(self, set_to_none: bool = True) -> None:
         """Zero gradients on GPU parameters.
-
+
         Args:
             set_to_none: If True, set gradients to None; otherwise zero them.
         """
@@ -67,15 +66,15 @@ def zero_grad(self, set_to_none: bool = True) -> None:
                 param.grad = None
             elif param.grad is not None:
                 param.grad.zero_()
-
+
     def step(self) -> None:
         """Perform optimizer step.
-
+
         Steps:
         1. Copy gradients from GPU to CPU (handling DTensor, ensuring contiguous FP32)
         2. Run optimizer update on CPU
         3. Copy updated parameters back to GPU
-
+
         Uses the same .to() pattern as update_weight_utils.py for proper memory layout.
         """
         # Copy gradients from GPU to CPU - handle DTensor and ensure FP32 for DeepSpeed AVX
@@ -85,29 +84,31 @@ def step(self) -> None:
                 grad_data = gpu_param.grad.detach()
                 if isinstance(grad_data, DTensor):
                     grad_data = grad_data.to_local()
-
+
                 # DeepSpeed's AVX operations expect FP32 gradients to match FP32 params
-                cpu_grad = grad_data.contiguous().to(device='cpu', dtype=torch.float32, non_blocking=True)
-
+                cpu_grad = grad_data.contiguous().to(device="cpu", dtype=torch.float32, non_blocking=True)
+
                 # Verify gradient properties for DeepSpeed AVX
                 assert cpu_grad.is_contiguous(), "CPU gradient must be contiguous for AVX"
                 assert cpu_grad.dtype == torch.float32, "CPU gradient must be FP32 for DeepSpeed"
-
+
                 cpu_param.grad = cpu_grad
             else:
                 cpu_param.grad = None
-
+
         torch.cuda.synchronize()
-
+
         # Run optimizer step on CPU
         self.cpu_optimizer.step()
-
+
         for gpu_param, cpu_param in zip(self.gpu_params, self.cpu_params):
-            updated_param = cpu_param.data.to(device=torch.cuda.current_device(), dtype=gpu_param.dtype, non_blocking=True)
-
+            updated_param = cpu_param.data.to(
+                device=torch.cuda.current_device(), dtype=gpu_param.dtype, non_blocking=True
+            )
+
             if isinstance(gpu_param.data, DTensor):
                 gpu_param.data.to_local().copy_(updated_param, non_blocking=True)
            else:
                 gpu_param.data.copy_(updated_param, non_blocking=True)
-
-        torch.cuda.synchronize()
+
+        torch.cuda.synchronize()
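The docstring above already lists the four steps; the following self-contained sketch walks the same shadow-copy pattern with `torch.optim.AdamW` standing in for `DeepSpeedCPUAdam` (so it runs without DeepSpeed) and without the DTensor local-shard handling the real wrapper performs:

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Linear(16, 4).to(device)

# 1. FP32, contiguous CPU shadow copies of the (possibly GPU) parameters.
#    copy=True guarantees a real copy even when the model is already CPU/FP32.
cpu_params = [
    p.detach().contiguous().to(device="cpu", dtype=torch.float32, copy=True).requires_grad_(True)
    for p in model.parameters()
]
cpu_opt = torch.optim.AdamW(cpu_params, lr=1e-3, weight_decay=0.1)  # stand-in for DeepSpeedCPUAdam

# One training step on the model itself.
loss = model(torch.randn(2, 16, device=device)).sum()
loss.backward()

# 2. Copy gradients to the CPU shadows (contiguous FP32), 3. step on CPU, ...
for gpu_p, cpu_p in zip(model.parameters(), cpu_params):
    cpu_p.grad = gpu_p.grad.detach().contiguous().to(device="cpu", dtype=torch.float32)
cpu_opt.step()
cpu_opt.zero_grad(set_to_none=True)

# 4. ... then copy the updated values back into the original parameters.
for gpu_p, cpu_p in zip(model.parameters(), cpu_params):
    gpu_p.data.copy_(cpu_p.data.to(device=gpu_p.device, dtype=gpu_p.dtype))
```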

slime/backends/fsdp_utils/update_weight_utils.py

Lines changed: 4 additions & 4 deletions
@@ -152,14 +152,14 @@ def update_weights(self) -> None:
 
         if self.full_params:
             print("Using FULL_STATE_DICT path with loading from CPU storage")
-
+
             # Load all parameters from CPU storage to GPU in one go
             # This is more memory intensive but faster than bucket-based approach
             named_tensors = []
             for name, cpu_param in self.weights["actor"].items():
                 gpu_param = cpu_param.to(device=torch.cuda.current_device(), non_blocking=True)
                 named_tensors.append((name, gpu_param))
-
+
             torch.cuda.synchronize()
 
             if use_flattened_tensor_bucket:
@@ -359,11 +359,11 @@ def update_weights(self) -> None:
             cpu_param = self.weights["actor"][name]
             gpu_param = cpu_param.to(device=torch.cuda.current_device(), dtype=torch.bfloat16, non_blocking=True)
             torch.cuda.synchronize()
-
+
             # Broadcast this single parameter
             single_param_dict = {name: gpu_param}
             self.request_update_params(single_param_dict)
-
+
             del gpu_param
             clear_memory()
 
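The second hunk belongs to the per-parameter update path: each CPU-cached weight is moved to the GPU in bf16, handed off, and freed before the next one to bound peak GPU memory. A rough sketch of that loop; `stream_weights` and `send_to_rollout` are placeholder names standing in for the surrounding method and its `request_update_params` call:

```python
import torch

def stream_weights(weights: dict[str, torch.Tensor], send_to_rollout) -> None:
    """Stream CPU-cached weights to the GPU one at a time to bound peak memory."""
    device = torch.cuda.current_device()
    for name, cpu_param in weights.items():
        gpu_param = cpu_param.to(device=device, dtype=torch.bfloat16, non_blocking=True)
        torch.cuda.synchronize()
        send_to_rollout({name: gpu_param})  # broadcast this single parameter
        del gpu_param  # release GPU memory before the next parameter
```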

slime/backends/megatron_utils/actor.py

Lines changed: 10 additions & 10 deletions
@@ -19,6 +19,8 @@
 from slime.utils.distributed_utils import get_gloo_group, init_process_group
 from slime.utils.memory_utils import clear_memory, print_memory
 from slime.utils.ray_utils import Box
+from slime.utils.reloadable_process_group import destroy_process_groups, monkey_patch_torch_dist, reload_process_groups
+from slime.utils.routing_replay import RoutingReplay
 from slime.utils.timer import Timer, timer
 from slime.utils.types import RolloutBatch
 from slime.utils.wandb_utils import init_wandb_secondary
@@ -40,6 +42,8 @@ def init(
         wandb_run_id: str,
         with_ref: bool = False,
     ) -> Optional[int]:
+        monkey_patch_torch_dist()
+
         super().init(args, role, wandb_run_id, with_ref)
 
         init(args)
@@ -158,8 +162,7 @@ def sleep(self, tags: Union[str, Tuple[str, ...]]) -> None:
 
         clear_memory()
         print_memory("before offload model")
-        if hasattr(mpu, "destroy_process_groups"):
-            mpu.destroy_process_groups()
+        destroy_process_groups()
 
         torch_memory_saver.pause()
 
@@ -184,8 +187,7 @@ def wake_up(self, tags: Union[str, Tuple[str, ...]]) -> None:
         torch_memory_saver.resume()
 
         clear_memory()
-        if hasattr(mpu, "reload_process_groups"):
-            mpu.reload_process_groups()
+        reload_process_groups()
         print_memory("after wake_up model")
 
     def _get_rollout_data(self, rollout_data_ref: Box) -> RolloutBatch:
@@ -375,8 +377,6 @@ def train_actor(self, rollout_id: int, rollout_data: RolloutBatch) -> None:
         )
 
         if self.args.use_routing_replay:
-            from megatron.core.transformer.moe.moe_utils import RoutingReplay
-
             RoutingReplay.clear_all()
 
         # update the cpu actor weight to the latest model
@@ -407,8 +407,8 @@ def update_weights(self) -> None:
         if self.args.debug_train_only or self.args.debug_rollout_only:
             return
 
-        if self.args.offload and hasattr(mpu, "reload_process_groups"):
-            mpu.reload_process_groups()
+        if self.args.offload:
+            reload_process_groups()
 
         rollout_engines, rollout_engine_lock, num_new_engines = ray.get(
             self.rollout_manager.get_rollout_engines_and_lock.remote()
@@ -434,8 +434,8 @@ def update_weights(self) -> None:
         else:
             self.update_cpu_params_dict(self.weights["old_actor"])
 
-        if self.args.offload and hasattr(mpu, "destroy_process_groups"):
-            mpu.destroy_process_groups()
+        if self.args.offload:
+            destroy_process_groups()
 
     def load_other_checkpoint(self, model_tag: str, path: str) -> None:
         old_args = self.args.load, self.args.no_load_optim, self.args.no_load_rng, self.args.finetune
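The `sleep`/`wake_up` hunks swap the old `hasattr` probes on `mpu` for direct imports from `slime.utils.reloadable_process_group`. A toy sketch of the resulting call order; the classes below are illustrative stand-ins, not slime or Megatron APIs:

```python
class FakeMemorySaver:
    """Stand-in for torch_memory_saver: pauses/resumes GPU allocations."""

    def pause(self) -> None:
        print("pause: GPU allocations released")

    def resume(self) -> None:
        print("resume: GPU allocations restored")


class FakeProcessGroups:
    """Stand-in for the reloadable_process_group helpers."""

    def destroy(self) -> None:
        print("destroy_process_groups()")

    def reload(self) -> None:
        print("reload_process_groups()")


def sleep(saver: FakeMemorySaver, groups: FakeProcessGroups) -> None:
    groups.destroy()  # tear down communicators before pausing GPU memory
    saver.pause()


def wake_up(saver: FakeMemorySaver, groups: FakeProcessGroups) -> None:
    saver.resume()
    groups.reload()  # rebuild groups before the next update_weights()


if __name__ == "__main__":
    saver, groups = FakeMemorySaver(), FakeProcessGroups()
    sleep(saver, groups)
    wake_up(saver, groups)
```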

slime/utils/data.py

Lines changed: 1 addition & 1 deletion
@@ -15,7 +15,7 @@
 # TODO: don't read the whole file into memory.
 def read_file(path):
     if path.endswith(".jsonl"):
-        df = pd.read_json(path, lines=True, dtype={'label': str})
+        df = pd.read_json(path, lines=True, dtype={"label": str})
     elif path.endswith(".parquet"):
         df = pd.read_parquet(path, dtype_backend="pyarrow")
     else:
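The `read_file` change is only a quote-style cleanup, but the `dtype={"label": str}` argument it keeps is what prevents numeric-looking labels from being coerced; a small illustration on synthetic data (not from the repo):

```python
import io

import pandas as pd

jsonl = io.StringIO('{"text": "a", "label": "007"}\n{"text": "b", "label": "1e3"}\n')
df = pd.read_json(jsonl, lines=True, dtype={"label": str})
print(df["label"].tolist())  # labels stay as strings, e.g. ['007', '1e3'], instead of numbers like 7.0 and 1000.0
```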
