 import logging
 import math
 import os
-from typing import Any
+from collections.abc import Mapping
+from dataclasses import dataclass, field, fields
 
 import torch
-import torchtitan.experiments.forge.train_spec as forge_train_spec
 from monarch.actor import current_rank, current_size, endpoint
-from omegaconf import DictConfig, OmegaConf
-from torch import nn
-from torchtitan.components.loss import LossFunction
-
-# from torchdata.stateful_dataloader import StatefulDataLoader
-# from torchtitan.components.checkpoint import ModelWrapper
-from torchtitan.components.lr_scheduler import LRSchedulersContainer
-from torchtitan.components.optimizer import OptimizersContainer
-from torchtitan.distributed import ParallelDims, utils as dist_utils
+from torchtitan.config.job_config import (
+    ActivationCheckpoint,
+    Checkpoint,
+    Comm,
+    Compile,
+    Float8,
+    LRScheduler,
+    Model,
+    Optimizer,
+    Parallelism,
+    Training,
+)
+
+from torchtitan.distributed import utils as dist_utils
 from torchtitan.experiments.forge.engine import ForgeEngine
 from torchtitan.experiments.forge.job_config import ForgeJobConfig
 
-# from tqdm import tqdm
-
 from forge.controller import ForgeActor
 
-# from forge.interfaces import RLLoss
-
-# stubs for now
-Checkpointer = Any
-Dataloader = Any
-MetricLogger = Any
-Profiler = Any
-Tokenizer = Any
-
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 
 
-class RLTrainer(ForgeActor, ForgeEngine):
-    job_config: ForgeJobConfig
-    train_spec: forge_train_spec.ForgeTrainSpec
-    parallel_dims: ParallelDims
-    model: list[nn.Module]
-    loss_fn: LossFunction
-    optimizer: OptimizersContainer
-    lr_scheduler: LRSchedulersContainer
-    checkpointer: Checkpointer
-    tokenizer: Tokenizer
-    train_dataloader: Dataloader
-    # val_dataloader: Dataloader
-    profiler: Profiler
-    device: torch.device
-    step: int
-
-    def __init__(self, config: DictConfig):
-        job_config = ForgeJobConfig().to_dict()
-        # Hack to deal with literal types from titan
-        job_config = OmegaConf.merge(job_config, config)
-
-        self.current_step = 0
-        self.num_training_steps = job_config.training.steps
-        self.gradient_accumulation_steps = 1  # Example value, adjust as needed
-        self._rank = current_rank().rank
-        self._size = math.prod(current_size().values())
-        self._init_dist()
-        super().__init__(job_config)
-
-    def _init_dist(self):
-        """Initializes torch distributed.
-
-        torchrun normally hands this, but we need to do it ourselves
+@dataclass
+class RLTrainer(ForgeActor):
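+    # Config sections mirroring torchtitan's job config; dict values are coerced in __post_init__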
+    model: Model = field(default_factory=Model)
+    optimizer: Optimizer = field(default_factory=Optimizer)
+    lr_scheduler: LRScheduler = field(default_factory=LRScheduler)
+    training: Training = field(default_factory=Training)
+    parallelism: Parallelism = field(default_factory=Parallelism)
+    checkpoint: Checkpoint = field(default_factory=Checkpoint)
+    activation_checkpoint: ActivationCheckpoint = field(
+        default_factory=ActivationCheckpoint
+    )
+    compile: Compile = field(default_factory=Compile)
+    float8: Float8 = field(default_factory=Float8)
+    comm: Comm = field(default_factory=Comm)
+
+    def __post_init__(self):
+        """Initializes config types and env variables.
+
+        torchrun normally handles env variables, but we need to do it ourselves
         in monarch for now.
 
-        We should consider putting this into ForgeActor, but having this
-        be explicit for now.
-
         """
+        # Coerce dict-like field values into their dataclass config types
+        for f in fields(self):
+            attr = getattr(self, f.name)
+            if isinstance(attr, Mapping):
+                setattr(self, f.name, f.type(**attr))
+            elif not isinstance(attr, f.type):
+                raise TypeError(
+                    f"{f.name} should be a {f.type} type or a dict like object"
+                )
+
+        self.current_step = 0
+        self.num_training_steps = self.training.steps
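+        # Placeholder: gradient accumulation is not applied yet (kept at 1)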
+        self.gradient_accumulation_steps = 1
+        self.rank = current_rank().rank
+        self.size = math.prod(current_size().values())
+
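+        # Set the env vars torchrun would normally provide so torch.distributed can initialize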
         env = {
-            "RANK": str(self._rank),
-            "LOCAL_RANK": str(self._rank),
-            "LOCAL_WORLD_SIZE": str(self._size),
-            "GROUP_RANK": str(self._size),
-            "GROUP_WORLD_SIZE": str(self._size),
-            "ROLE_RANK": str(self._rank),
-            "ROLE_WORLD_SIZE": str(self._size),
+            "RANK": str(self.rank),
+            "LOCAL_RANK": str(self.rank),
+            "LOCAL_WORLD_SIZE": str(self.size),
+            "GROUP_RANK": str(self.size),
+            "GROUP_WORLD_SIZE": str(self.size),
+            "ROLE_RANK": str(self.rank),
+            "ROLE_WORLD_SIZE": str(self.size),
             "ROLE_NAME": "rank",
-            "WORLD_SIZE": str(self._size),
+            "WORLD_SIZE": str(self.size),
             "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
         }
         os.environ.update(env)
-        logger.info("env: {}".format(env))
 
     @endpoint
     async def setup(self):
-        self.checkpointer.load(step=self.current_step)
-        # self.profiler = self.setup_profiler(self.train_config.profiler_config)
-        # self.logger = self.setup_logger(self.train_config.logger_config)
-        self.optimizers.zero_grad()
-
-        # self.pbar = tqdm(
-        #     initial=0,
-        #     total=self.num_training_steps,
-        #     desc=f"{self.current_step}",
-        # )
-        #
+        # TODO: update ForgeEngine to not use ForgeJobConfig
+        engine_config = {f.name: getattr(self, f.name) for f in fields(self)}
+        self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
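+        # The engine provides the model parts, optimizers, lr schedulers, loss fn, and checkpointer used below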
+        self.engine.checkpointer.load(step=self.current_step)
+        self.engine.optimizers.zero_grad()
 
     def forward_backward(
         self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
     ) -> torch.Tensor:
-        model_parts = self.model_parts
-        parallel_dims = self.parallel_dims
+        model_parts = self.engine.model_parts
+        parallel_dims = self.engine.parallel_dims
 
         # apply context parallelism if cp is enabled
         # ensure CP handles the separate freqs_cis buffer for each pp stage
         inputs = input_dict["tokens"]
 
-        if getattr(self.model_args, "use_flex_attn", False):
+        if getattr(self.engine.model_args, "use_flex_attn", False):
             cp_mesh = (
                 parallel_dims.world_mesh["cp"] if parallel_dims.cp_enabled else None
             )
-            init_attention_mask(inputs, self.tokenizer.base_tokenizer.eos_id, cp_mesh)
+            init_attention_mask(
+                inputs, self.engine.tokenizer.base_tokenizer.eos_id, cp_mesh
+            )
 
         optional_context_parallel_ctx = (
             dist_utils.create_context_parallel_ctx(
@@ -164,11 +152,11 @@ def forward_backward(
         # )
         else:
             # Non-PP forward / backward
-            with self.train_context(optional_context_parallel_ctx):
+            with self.engine.train_context(optional_context_parallel_ctx):
                 assert len(model_parts) == 1
-                with self.maybe_enable_amp:
+                with self.engine.maybe_enable_amp:
                     pred = model_parts[0](inputs)
-                    loss = self.loss_fn(pred, labels)
+                    loss = self.engine.loss_fn(pred, labels)
                     # need to free to before bwd to avoid peaking memory
                     del pred
                     loss.backward()
@@ -191,32 +179,92 @@ def train_step(self, batch) -> None:
         # TODO: convert to GRPO Loss
         labels = batch.pop("labels")
         loss = self.forward_backward(batch, labels)
-        # self.pbar.update(1)
-        # self.pbar.set_description(f"{self.current_step}|Loss: {loss}")
 
-        self.optimizers.step()
-        self.optimizers.zero_grad()
-        self.lr_schedulers.step()
+        self.engine.optimizers.step()
+        self.engine.optimizers.zero_grad()
+        self.engine.lr_schedulers.step()
 
-        # self.profiler.step()
         self.current_step += 1
-
-        # if self.current_step % self.train_config.val_every_n_steps == 0:
-        #     self.validate()
-        self.checkpointer.save(
+        self.engine.checkpointer.save(
             curr_step=self.current_step,
             last_step=self.current_step == self.num_training_steps,
         )
 
+    # TODO: integrate the grpo app step with the above step
+    # def train_step(self, batch: list[Episode]):
+    #     total_loss = 0.0
+    #     num_groups_processed = 0
+    #
+    #     for episode in batch:
+    #         groups = episode.groups
+    #
+    #         # Collect all response texts and corresponding data
+    #         response_texts = []
+    #         ref_logprobs_list = []
+    #         advantages_list = []
+    #
+    #         for group in groups:
+    #             response_texts.append(group.response)
+    #             ref_logprobs_list.append(group.ref_logprobs)
+    #             advantages_list.append(group.advantage)
+    #
+    #         # Tokenize all responses in batch
+    #         tokenized = self.tokenizer(
+    #             response_texts,
+    #             padding=True,
+    #             truncation=True,
+    #             return_tensors="pt",
+    #             max_length=512,  # Adjust based on your needs
+    #         )
+    #
+    #         input_ids = tokenized["input_ids"].to(self.device)
+    #         attention_mask = tokenized["attention_mask"].to(self.device)
+    #
+    #         # Compute current policy log probabilities using the model
+    #         current_logprobs = compute_sequence_logprobs(
+    #             self.model, input_ids, attention_mask, requires_grad=True
+    #         )
+    #
+    #         # Convert ref_logprobs and advantages to tensors
+    #         ref_logprobs_tensor = torch.stack(ref_logprobs_list).to(self.device)
+    #         advantages_tensor = torch.tensor(advantages_list, dtype=torch.float32).to(
+    #             self.device
+    #         )
+    #
+    #         # Compute GRPO loss components
+    #         # Ratio between current policy and reference policy
+    #         ratio = torch.exp(current_logprobs - ref_logprobs_tensor)
+    #
+    #         # Policy gradient loss weighted by advantages
+    #         pg_loss = -torch.mean(ratio * advantages_tensor)
+    #
+    #         # KL penalty to prevent policy from deviating too far from reference
+    #         kl_penalty = self.beta * torch.mean(
+    #             (current_logprobs - ref_logprobs_tensor) ** 2
+    #         )
+    #
+    #         # Total GRPO loss
+    #         loss = pg_loss + kl_penalty
+    #         total_loss += loss.item()
+    #         num_groups_processed += len(groups)
+    #
+    #         self.optimizer.zero_grad()
+    #         loss.backward()
+    #
+    #         # Gradient clipping (optional but recommended for stability)
+    #         torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+    #
+    #         self.optimizer.step()
+    #
+    #     avg_loss = total_loss / len(batch) if batch else 0.0
+    #
+    #     return {"loss": avg_loss, "groups_processed": num_groups_processed}
+
     @endpoint
     def push_weights(self) -> None:
         pass
 
     @endpoint
     async def cleanup(self) -> None:
-        # self.pbar.close()
-        if self.checkpointer:
-            self.checkpointer.close()
-
-    def __repr__(self) -> str:
-        return "Trainer"
+        if self.engine.checkpointer:
+            self.engine.checkpointer.close()