@@ -264,12 +264,21 @@ def setup_fsdp(model: nn.Module, config: ModelConfig, parallel_dims: ParallelDims):
 
 
 def load_dcp_from_hf(model: nn.Module, config: ModelConfig, parallel_dims: ParallelDims):
-    model.to_empty(device="cuda")
+    device = "cpu" if config.fsdp_cpu_offload else "cuda"
+    model.to_empty(device=device)
     torch.distributed.barrier()
 
+    def _init_buffers_post_meta():
+        if isinstance(model, PreTrainedModelPrimeRL):
+            model.init_buffers_post_meta()
+        else:
+            fix_model_post_empty(model)
+
     logger = get_logger()
     if config.debug.random_init:
         logger.warning("Randomly initializing model. Skipping loading weights from HF.")
+        _init_buffers_post_meta()
+        _move_buffers_to_cuda(model, config)
         return
 
     if not Path(config.name).exists():
@@ -329,10 +338,10 @@ def load_dcp_from_hf(model: nn.Module, config: ModelConfig, parallel_dims: ParallelDims):
         state_dict,
         storage_reader=HuggingFaceStorageReader(path=snapshot_path.as_posix()),
     )
-    if isinstance(model, PreTrainedModelPrimeRL):
-        model.init_buffers_post_meta()
-    else:
-        fix_model_post_empty(model)
+    _init_buffers_post_meta()
+
+    _move_buffers_to_cuda(model, config)
+
     lora_modules = [m for m in model.modules() if hasattr(m, "_init_lora_parameters")]
     if lora_modules:
         generator: torch.Generator | None = None
@@ -415,6 +424,15 @@ def apply_ep(model: nn.Module, parallel_dims: ParallelDims):
     )
 
 
+def _move_buffers_to_cuda(model: nn.Module, config: ModelConfig) -> None:
+    """FSDP CPU offloading only manages parameters, not buffers. Move buffers to CUDA."""
+    if not config.fsdp_cpu_offload:
+        return
+    for _, buffer in model.named_buffers():
+        if buffer.device.type == "cpu":
+            buffer.data = buffer.data.to("cuda")
+
+
 def setup_model(
     config: ModelConfig, parallel_dims: ParallelDims, loading_from_checkpoint_later: bool = False
 ) -> nn.Module:
@@ -456,19 +474,25 @@ def setup_model(
 
     setup_fsdp(model, config, parallel_dims)
 
+    if not possible_to_load_to_meta:
+        _move_buffers_to_cuda(model, config)
+
     # 2. if we can load to meta, we either:
     if possible_to_load_to_meta:
         # - load from checkpoint later if needed
         if loading_from_checkpoint_later:
             logger.warning(
                 "Skipping loading weights. Initializing an empty model on device, loading from checkpoint later."
             )
-            model.to_empty(device="cuda")
+            device = "cpu" if config.fsdp_cpu_offload else "cuda"
+            model.to_empty(device=device)
             torch.distributed.barrier()
             if isinstance(model, PreTrainedModelPrimeRL):
                 model.init_buffers_post_meta()
             else:
                 fix_model_post_empty(model)
+
+            _move_buffers_to_cuda(model, config)
         # - or load from HF with dcp
         else:
             load_dcp_from_hf(model, config, parallel_dims)
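The new `_move_buffers_to_cuda` helper exists because, as its docstring notes, FSDP CPU offloading only manages parameters, not buffers: after `to_empty(device="cpu")` the buffers stay on CPU unless moved explicitly. Below is a minimal standalone sketch of that behavior; `TinyModel` and `move_buffers_to_cuda` are illustrative stand-ins, not part of this change.

```python
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    """Illustrative module with one parameter and one non-parameter buffer."""

    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        # Buffers (e.g. rotary-embedding caches) are not parameters, so FSDP
        # CPU offload does not manage their device placement.
        self.register_buffer("inv_freq", torch.arange(4, dtype=torch.float32))


def move_buffers_to_cuda(model: nn.Module) -> None:
    """Same idea as the diff's _move_buffers_to_cuda: move any CPU buffers onto the GPU."""
    for _, buffer in model.named_buffers():
        if buffer.device.type == "cpu":
            buffer.data = buffer.data.to("cuda")


if __name__ == "__main__":
    # Build on the meta device, then materialize storage on CPU, mirroring the
    # fsdp_cpu_offload code path in the diff.
    with torch.device("meta"):
        model = TinyModel()
    model.to_empty(device="cpu")

    # Buffers are on CPU after materialization.
    print({name: buf.device.type for name, buf in model.named_buffers()})

    if torch.cuda.is_available():
        move_buffers_to_cuda(model)
        # Buffers now live on the GPU; parameters can stay offloaded on CPU.
        print({name: buf.device.type for name, buf in model.named_buffers()})
```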