2121from monarch .actor import current_rank , current_size , endpoint
2222from torch import Tensor
2323from torch .distributed .checkpoint ._nested_dict import flatten_state_dict
24- from torchstore .state_dict_utils import DELIM
2524from torchtitan .config .job_config import (
2625 ActivationCheckpoint ,
2726 Checkpoint ,
@@ -114,8 +113,6 @@ class RLTrainer(ForgeActor):
# Trainer-only settings. setup() pops each of these names out of engine_config
# ("Not part of job config") before ForgeJobConfig is constructed.
# state_dict_key: prefix of the key under which pushed weights are stored; the
# deprecated push path combines it with DELIM and the policy version.
114113 state_dict_key : str = "model_state_dict"
# use_dcp: in the deprecated push path, True selects dcp.save + ts.put(metadata),
# False selects ts.put_state_dict.
115114 use_dcp : bool = True
# dcp_path: not read anywhere in the visible hunks — presumably the directory
# used for DCP output; confirm against the rest of the file.
116115 dcp_path : str = "forge_dcp_tmp"
# NOTE(review): the two fields below appear as removed ("-") lines in this diff.
# vllm_tp_DEPRECATED: forwarded by push_weights to _push_weights_DEPRECATED.
117- vllm_tp_DEPRECATED : int = 1 # noqa: N815
# use_vllm_builtin_load: when False, push_weights takes the deprecated path.
118- use_vllm_builtin_load : bool = True
119116
120117 def __post_init__ (self ):
121118 """Initializes config types and env variables.
@@ -159,6 +156,8 @@ def __post_init__(self):
159156 "PYTORCH_CUDA_ALLOC_CONF" : "expandable_segments:True" ,
160157 }
161158 os .environ .update (env )
159+ logger .info ("Compiling loss" )
160+ self .loss = torch .compile (self .loss )
162161
163162 @endpoint
164163 async def setup (self ):
@@ -168,9 +167,7 @@ async def setup(self):
168167 "loss" ,
169168 "state_dict_key" ,
170169 "use_dcp" ,
171- "use_vllm_builtin_load" ,
172170 "dcp_path" ,
173- "vllm_tp_DEPRECATED" ,
174171 }:
175172 engine_config .pop (key ) # Not part of job config
176173 self .engine = ForgeEngine (ForgeJobConfig (** engine_config ))
@@ -302,76 +299,12 @@ async def train_step(
302299 t .stop ()
303300 return loss
304301
# NOTE(review): every line of this endpoint appears as a removed ("-") line in
# this diff view.
# Deprecated endpoint: a thin wrapper that forwards both arguments to
# _push_weights_DEPRECATED and returns whatever that coroutine returns.
#   policy_version: version number forwarded to the helper.
#   vllm_tp_DEPRECATED: forwarded unchanged to the helper; defaults to 1.
305- @endpoint
306- async def push_weights_DEPRECATED ( # noqa: N802
307- self , policy_version : int , vllm_tp_DEPRECATED : int = 1
308- ) -> None : # noqa: N802
309- """[Deprecated] This method pushes weights to torchstore in the vllm format,
310- which is buggy and not scalable to other models.
311- Deprecated in favor of push_weights."""
# All of the work happens in the private helper; exceptions raised there
# propagate to the caller of this endpoint.
312- return await self ._push_weights_DEPRECATED (policy_version , vllm_tp_DEPRECATED )
313-
# NOTE(review): every line of this helper appears as a removed ("-") line in
# this diff view.
# Deprecated weight-push implementation. Steps, in order:
#   1. read the "model" entry of the checkpointer's states and flatten its state dict,
#   2. convert it with checkpointer.sd_adapter.to_hf,
#   3. convert that result with _qwen3_hf_to_vllm,
#   4. store it under a key built from state_dict_key, DELIM and policy_version,
#      either via dcp.save + ts.put(metadata) or via ts.put_state_dict.
# Parameters:
#   policy_version: used in the storage key and passed to the cleanup call.
#   vllm_tp_DEPRECATED: passed to _qwen3_hf_to_vllm as vllm_tp.
# Raises:
#   RuntimeError: if "model" is missing from checkpointer.states, or if
#   checkpointer.sd_adapter is None.
314- async def _push_weights_DEPRECATED ( # noqa: N802
315- self , policy_version : int , vllm_tp_DEPRECATED : int
316- ) -> None : # noqa: N802
317- # Save to torchstore. Hacking in to the Checkpointer's prepped state-dict for now.
318- # TODO:
319- # 1. Checkpoint invokes state-dict flattening during dcp_save for [MODEL].
320- # May need to replicate the same in this code path.
321- # 2. Unify CheckpointManager and TorchStore weights save control path.
# Guard: nothing can be pushed without a registered "model" state.
322- if "model" not in self .engine .checkpointer .states :
323- raise RuntimeError ("Model state not found in checkpointer state" )
324-
# Only the flattened dict is used; the second return value is discarded.
325- sd = self .engine .checkpointer .states ["model" ].state_dict ()
326- flattened_state_dict , _ = flatten_state_dict (sd )
327-
# The HF conversion is delegated to sd_adapter, so it must be present.
328- if self .engine .checkpointer .sd_adapter is None :
329- raise RuntimeError (
330- "Trying to save checkpoint in HF safetensors format, but sd_adapter is not provided."
331- )
332- hf_state_dict = self .engine .checkpointer .sd_adapter .to_hf (flattened_state_dict )
333-
334- # TODO: Figure out how to gracefully handle which model to-vLLM conversion is needed
# The converter is hard-wired to the Qwen3-specific helper (see the TODO above).
335- vllm_ready_hf_sd = _qwen3_hf_to_vllm (
336- sd = hf_state_dict ,
337- num_layers = self .engine .model_args .n_layers ,
338- vllm_tp = vllm_tp_DEPRECATED ,
339- )
340-
# NOTE(review): tokens are space-separated throughout this rendering, so the
# literal spaces between the fields below are probably an extraction artifact —
# confirm the exact key format against the real file.
341- key = f"{ self .state_dict_key } { DELIM } { policy_version } "
342- if self .use_dcp :
343- # TODO - DCP should probably be being saved to NFS explicitly?
344- # Right now it will only save everything locally
# The key string is passed as the first argument to FileSystemWriter.
345- storage_writer = torch .distributed .checkpoint .FileSystemWriter (
346- key , single_file_per_rank = False , thread_count = 8
347- )
348- metadata = dcp .save (
349- storage_writer = storage_writer , state_dict = vllm_ready_hf_sd
350- )
# In this branch only the value returned by dcp.save is put into torchstore.
351- await ts .put (key , metadata )
352-
353- # Delete old weight versions if they exist
# Cleanup is invoked on rank 0 only.
354- if self .rank == 0 :
355- cleanup_old_weight_versions (
356- state_dict_key = self .state_dict_key ,
357- delim = DELIM ,
358- current_policy_version = policy_version ,
359- )
360- else :
# Non-DCP branch: the converted state dict itself is put under the key.
361- await ts .put_state_dict (vllm_ready_hf_sd , key )
362-
363302 @endpoint
364303 async def push_weights (self , policy_version : int ) -> None :
365304 """Push weights to torchstore in HF format."""
366305 t = Tracer ("rl_trainer_perf/push_weights" , timer = "gpu" , track_memory = True )
367306 t .start ()
368307 logger .info (f"Pushing weights for policy version { policy_version } " )
369- if not self .use_vllm_builtin_load :
370- result = await self ._push_weights_DEPRECATED (
371- policy_version , self .vllm_tp_DEPRECATED
372- )
373- t .step ("push_weights_DEPRECATED" )
374- return result
375308
376309 start_time = time .perf_counter ()
377310 if "model" not in self .engine .checkpointer .states :
0 commit comments