Skip to content

Commit 82a8241

Browse files
samsja and sami jaghouar authored
fix cpu offloading (#1636)
* fix cpu offloading * fix offload for moe Fix buffer init in random init path Co-authored-by: sami <[email protected]> Fix buffer device move when meta load unavailable Co-authored-by: sami <[email protected]> --------- Co-authored-by: sami jaghouar <[email protected]>
1 parent 53a5ab8 commit 82a8241

File tree

1 file changed

+30
-6
lines changed

1 file changed

+30
-6
lines changed

src/prime_rl/trainer/model.py

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -264,12 +264,21 @@ def setup_fsdp(model: nn.Module, config: ModelConfig, parallel_dims: ParallelDim
264264

265265

266266
def load_dcp_from_hf(model: nn.Module, config: ModelConfig, parallel_dims: ParallelDims):
267-
model.to_empty(device="cuda")
267+
device = "cpu" if config.fsdp_cpu_offload else "cuda"
268+
model.to_empty(device=device)
268269
torch.distributed.barrier()
269270

271+
def _init_buffers_post_meta():
    # Re-initialize module buffers after the model was materialized from the
    # meta device (to_empty() allocates storage but leaves contents undefined).
    # NOTE(review): PreTrainedModelPrimeRL exposes its own buffer rebuild hook;
    # for any other model we fall back to the generic fix_model_post_empty
    # helper — presumably equivalent, confirm against its implementation.
    if isinstance(model, PreTrainedModelPrimeRL):
        model.init_buffers_post_meta()
    else:
        fix_model_post_empty(model)
276+
270277
logger = get_logger()
271278
if config.debug.random_init:
272279
logger.warning("Randomly initializing model. Skipping loading weights from HF.")
280+
_init_buffers_post_meta()
281+
_move_buffers_to_cuda(model, config)
273282
return
274283

275284
if not Path(config.name).exists():
@@ -329,10 +338,10 @@ def load_dcp_from_hf(model: nn.Module, config: ModelConfig, parallel_dims: Paral
329338
state_dict,
330339
storage_reader=HuggingFaceStorageReader(path=snapshot_path.as_posix()),
331340
)
332-
if isinstance(model, PreTrainedModelPrimeRL):
333-
model.init_buffers_post_meta()
334-
else:
335-
fix_model_post_empty(model)
341+
_init_buffers_post_meta()
342+
343+
_move_buffers_to_cuda(model, config)
344+
336345
lora_modules = [m for m in model.modules() if hasattr(m, "_init_lora_parameters")]
337346
if lora_modules:
338347
generator: torch.Generator | None = None
@@ -415,6 +424,15 @@ def apply_ep(model: nn.Module, parallel_dims: ParallelDims):
415424
)
416425

417426

427+
def _move_buffers_to_cuda(model: nn.Module, config: ModelConfig) -> None:
    """Ensure all module buffers live on CUDA when FSDP CPU offload is active.

    FSDP's CPU offloading manages only parameters, not buffers, so any buffer
    left on the CPU must be moved to the GPU explicitly. No-op when
    ``config.fsdp_cpu_offload`` is disabled.
    """
    if not config.fsdp_cpu_offload:
        return
    for _, buf in model.named_buffers():
        if buf.device.type != "cpu":
            continue
        # Swap the underlying storage in place so module references stay valid.
        buf.data = buf.data.to("cuda")
434+
435+
418436
def setup_model(
419437
config: ModelConfig, parallel_dims: ParallelDims, loading_from_checkpoint_later: bool = False
420438
) -> nn.Module:
@@ -456,19 +474,25 @@ def setup_model(
456474

457475
setup_fsdp(model, config, parallel_dims)
458476

477+
if not possible_to_load_to_meta:
478+
_move_buffers_to_cuda(model, config)
479+
459480
# 2. if we can load to meta, we either:
460481
if possible_to_load_to_meta:
461482
# - load from checkpoint later if needed
462483
if loading_from_checkpoint_later:
463484
logger.warning(
464485
"Skipping loading weights. Initializing an empty model on device, loading from checkpoint later."
465486
)
466-
model.to_empty(device="cuda")
487+
device = "cpu" if config.fsdp_cpu_offload else "cuda"
488+
model.to_empty(device=device)
467489
torch.distributed.barrier()
468490
if isinstance(model, PreTrainedModelPrimeRL):
469491
model.init_buffers_post_meta()
470492
else:
471493
fix_model_post_empty(model)
494+
495+
_move_buffers_to_cuda(model, config)
472496
# - or load from HF with dcp
473497
else:
474498
load_dcp_from_hf(model, config, parallel_dims)

0 commit comments

Comments (0)