Skip to content

Commit f566b5c

Browse files
cursoragent authored and sami committed
Fix buffer device move when meta load unavailable
Co-authored-by: sami <[email protected]>
1 parent e9c264c commit f566b5c

File tree

1 file changed

+15
-15
lines changed

1 file changed

+15
-15
lines changed

src/prime_rl/trainer/model.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -268,13 +268,6 @@ def load_dcp_from_hf(model: nn.Module, config: ModelConfig, parallel_dims: Paral
268268
model.to_empty(device=device)
269269
torch.distributed.barrier()
270270

271-
def _move_buffers_to_cuda():
272-
"""FSDP CPU offloading only manages parameters, not buffers. Move buffers to CUDA."""
273-
if config.fsdp_cpu_offload:
274-
for name, buffer in model.named_buffers():
275-
if buffer.device.type == "cpu":
276-
buffer.data = buffer.data.to("cuda")
277-
278271
def _init_buffers_post_meta():
279272
if isinstance(model, PreTrainedModelPrimeRL):
280273
model.init_buffers_post_meta()
@@ -285,7 +278,7 @@ def _init_buffers_post_meta():
285278
if config.debug.random_init:
286279
logger.warning("Randomly initializing model. Skipping loading weights from HF.")
287280
_init_buffers_post_meta()
288-
_move_buffers_to_cuda()
281+
_move_buffers_to_cuda(model, config)
289282
return
290283

291284
if not Path(config.name).exists():
@@ -347,7 +340,7 @@ def _init_buffers_post_meta():
347340
)
348341
_init_buffers_post_meta()
349342

350-
_move_buffers_to_cuda()
343+
_move_buffers_to_cuda(model, config)
351344

352345
lora_modules = [m for m in model.modules() if hasattr(m, "_init_lora_parameters")]
353346
if lora_modules:
@@ -431,6 +424,15 @@ def apply_ep(model: nn.Module, parallel_dims: ParallelDims):
431424
)
432425

433426

427+
def _move_buffers_to_cuda(model: nn.Module, config: ModelConfig) -> None:
428+
"""FSDP CPU offloading only manages parameters, not buffers. Move buffers to CUDA."""
429+
if not config.fsdp_cpu_offload:
430+
return
431+
for _, buffer in model.named_buffers():
432+
if buffer.device.type == "cpu":
433+
buffer.data = buffer.data.to("cuda")
434+
435+
434436
def setup_model(
435437
config: ModelConfig, parallel_dims: ParallelDims, loading_from_checkpoint_later: bool = False
436438
) -> nn.Module:
@@ -472,6 +474,9 @@ def setup_model(
472474

473475
setup_fsdp(model, config, parallel_dims)
474476

477+
if not possible_to_load_to_meta:
478+
_move_buffers_to_cuda(model, config)
479+
475480
# 2. if we can load to meta, we either:
476481
if possible_to_load_to_meta:
477482
# - load from checkpoint later if needed
@@ -487,12 +492,7 @@ def setup_model(
487492
else:
488493
fix_model_post_empty(model)
489494

490-
# FSDP CPU offloading only manages parameters, not buffers.
491-
# Buffers must be on CUDA for the forward pass.
492-
if config.fsdp_cpu_offload:
493-
for name, buffer in model.named_buffers():
494-
if buffer.device.type == "cpu":
495-
buffer.data = buffer.data.to("cuda")
495+
_move_buffers_to_cuda(model, config)
496496
# - or load from HF with dcp
497497
else:
498498
load_dcp_from_hf(model, config, parallel_dims)

0 commit comments

Comments (0)