@@ -268,13 +268,6 @@ def load_dcp_from_hf(model: nn.Module, config: ModelConfig, parallel_dims: Paral
268268 model .to_empty (device = device )
269269 torch .distributed .barrier ()
270270
271- def _move_buffers_to_cuda ():
272- """FSDP CPU offloading only manages parameters, not buffers. Move buffers to CUDA."""
273- if config .fsdp_cpu_offload :
274- for name , buffer in model .named_buffers ():
275- if buffer .device .type == "cpu" :
276- buffer .data = buffer .data .to ("cuda" )
277-
278271 def _init_buffers_post_meta ():
279272 if isinstance (model , PreTrainedModelPrimeRL ):
280273 model .init_buffers_post_meta ()
@@ -285,7 +278,7 @@ def _init_buffers_post_meta():
285278 if config .debug .random_init :
286279 logger .warning ("Randomly initializing model. Skipping loading weights from HF." )
287280 _init_buffers_post_meta ()
288- _move_buffers_to_cuda ()
281+ _move_buffers_to_cuda (model , config )
289282 return
290283
291284 if not Path (config .name ).exists ():
@@ -347,7 +340,7 @@ def _init_buffers_post_meta():
347340 )
348341 _init_buffers_post_meta ()
349342
350- _move_buffers_to_cuda ()
343+ _move_buffers_to_cuda (model , config )
351344
352345 lora_modules = [m for m in model .modules () if hasattr (m , "_init_lora_parameters" )]
353346 if lora_modules :
@@ -431,6 +424,15 @@ def apply_ep(model: nn.Module, parallel_dims: ParallelDims):
431424 )
432425
433426
def _move_buffers_to_cuda(model: nn.Module, config: ModelConfig) -> None:
    """Relocate any CPU-resident buffers of ``model`` onto CUDA.

    FSDP CPU offloading only manages parameters, not buffers, so buffers
    that are still on the CPU are moved explicitly. The function does
    nothing unless ``config.fsdp_cpu_offload`` is truthy.

    Args:
        model: Module whose registered buffers are inspected.
        config: Model configuration; only ``fsdp_cpu_offload`` is read.
    """
    if config.fsdp_cpu_offload:
        # Collect first, then reassign: only buffers currently on CPU are touched.
        cpu_resident = [buf for buf in model.buffers() if buf.device.type == "cpu"]
        for buf in cpu_resident:
            # Rebind the underlying storage in place so the module keeps the
            # same buffer object, now backed by CUDA memory.
            buf.data = buf.data.to("cuda")
434+
435+
434436def setup_model (
435437 config : ModelConfig , parallel_dims : ParallelDims , loading_from_checkpoint_later : bool = False
436438) -> nn .Module :
@@ -472,6 +474,9 @@ def setup_model(
472474
473475 setup_fsdp (model , config , parallel_dims )
474476
477+ if not possible_to_load_to_meta :
478+ _move_buffers_to_cuda (model , config )
479+
475480 # 2. if we can load to meta, we either:
476481 if possible_to_load_to_meta :
477482 # - load from checkpoint later if needed
@@ -487,12 +492,7 @@ def setup_model(
487492 else :
488493 fix_model_post_empty (model )
489494
490- # FSDP CPU offloading only manages parameters, not buffers.
491- # Buffers must be on CUDA for the forward pass.
492- if config .fsdp_cpu_offload :
493- for name , buffer in model .named_buffers ():
494- if buffer .device .type == "cpu" :
495- buffer .data = buffer .data .to ("cuda" )
495+ _move_buffers_to_cuda (model , config )
496496 # - or load from HF with dcp
497497 else :
498498 load_dcp_from_hf (model , config , parallel_dims )
0 commit comments