Commit 4ca0685

Merge remote-tracking branch 'origin/main' into ref-actor
2 parents: c3f2c8c + a13ccbf

File tree (4 files changed, +154 / -103 lines): apps/rl/llama3_8b.yaml, apps/rl/main.py, pyproject.toml, src/forge/actors/trainer.py


apps/rl/llama3_8b.yaml

Lines changed: 3 additions & 1 deletion
@@ -34,9 +34,11 @@ trainer:
   seq_len: 2048
   max_norm: 1.0
   steps: 5
-  compile: false
   dataset: "c4"

+  compile:
+    enable: false
+
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: -1
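
The `compile` option changes from a bare boolean into a nested section so it lines up with torchtitan's `Compile` config group (imported in `trainer.py` below). A minimal sketch of how the new shape reads once loaded, assuming the YAML is loaded with OmegaConf as `apps/rl/main.py` does:

```python
# Minimal sketch: reading the new nested `compile` block from the app config.
# Assumes the YAML is loaded with OmegaConf, as apps/rl/main.py does.
from omegaconf import OmegaConf

cfg = OmegaConf.load("apps/rl/llama3_8b.yaml")

# Before this change, cfg.trainer.compile was a bare bool (False).
# Now it is a mapping with an `enable` flag.
assert cfg.trainer.compile.enable is False
```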

apps/rl/main.py

Lines changed: 1 addition & 1 deletion
@@ -30,7 +30,7 @@ async def run(cfg: DictConfig):
         spawn_actors(
             name="trainer",
             actor_cls=RLTrainer,
-            cfg={"config": cfg.trainer},
+            cfg=cfg.trainer,
             processes=cfg.trainer.pop("processes"),
             set_address=True,
         ),
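
The `{"config": ...}` wrapper goes away because `RLTrainer` no longer takes a single `config` argument: it is now a dataclass (see `trainer.py` below) whose fields mirror the trainer sub-sections. A small sketch of the shape change, with placeholder values, assuming the mapping passed as `cfg` ultimately reaches the actor's constructor as keyword arguments:

```python
# Placeholder values; only the shape matters here.
old_cfg = {"config": {"training": {"steps": 5}, "compile": {"enable": False}}}
new_cfg = {"training": {"steps": 5}, "compile": {"enable": False}}

# Each top-level key of new_cfg lines up with a field on the RLTrainer
# dataclass, which __post_init__ then coerces into the matching torchtitan
# config type.
assert set(new_cfg) <= {
    "model", "optimizer", "lr_scheduler", "training", "parallelism",
    "checkpoint", "activation_checkpoint", "compile", "float8", "comm",
}
```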

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@ dependencies = [
     "tokenizers",
     # Miscellaneous
     "omegaconf",
+    "wandb",
 ]
 dynamic = ["version"]
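
`wandb` joins the dependency list, but nothing in this commit uses it yet. The snippet below is only a generic sketch of the kind of metric logging it enables, not code from this repo; the project and run names are placeholders.

```python
# Generic wandb logging sketch (not from this commit); offline mode avoids
# needing an API key when trying it locally.
import wandb

run = wandb.init(project="forge-rl", name="llama3_8b", mode="offline")
for step, loss in enumerate([2.31, 2.05, 1.98], start=1):  # placeholder losses
    wandb.log({"train/loss": loss}, step=step)
run.finish()
```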

src/forge/actors/trainer.py

Lines changed: 149 additions & 101 deletions
@@ -8,123 +8,111 @@
 import logging
 import math
 import os
-from typing import Any
+from collections.abc import Mapping
+from dataclasses import dataclass, field, fields

 import torch
-import torchtitan.experiments.forge.train_spec as forge_train_spec
 from monarch.actor import current_rank, current_size, endpoint
-from omegaconf import DictConfig, OmegaConf
-from torch import nn
-from torchtitan.components.loss import LossFunction
-
-# from torchdata.stateful_dataloader import StatefulDataLoader
-# from torchtitan.components.checkpoint import ModelWrapper
-from torchtitan.components.lr_scheduler import LRSchedulersContainer
-from torchtitan.components.optimizer import OptimizersContainer
-from torchtitan.distributed import ParallelDims, utils as dist_utils
+from torchtitan.config.job_config import (
+    ActivationCheckpoint,
+    Checkpoint,
+    Comm,
+    Compile,
+    Float8,
+    LRScheduler,
+    Model,
+    Optimizer,
+    Parallelism,
+    Training,
+)
+
+from torchtitan.distributed import utils as dist_utils
 from torchtitan.experiments.forge.engine import ForgeEngine
 from torchtitan.experiments.forge.job_config import ForgeJobConfig

-# from tqdm import tqdm
-
 from forge.controller import ForgeActor

-# from forge.interfaces import RLLoss
-
-# stubs for now
-Checkpointer = Any
-Dataloader = Any
-MetricLogger = Any
-Profiler = Any
-Tokenizer = Any
-
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)


-class RLTrainer(ForgeActor, ForgeEngine):
-    job_config: ForgeJobConfig
-    train_spec: forge_train_spec.ForgeTrainSpec
-    parallel_dims: ParallelDims
-    model: list[nn.Module]
-    loss_fn: LossFunction
-    optimizer: OptimizersContainer
-    lr_scheduler: LRSchedulersContainer
-    checkpointer: Checkpointer
-    tokenizer: Tokenizer
-    train_dataloader: Dataloader
-    # val_dataloader: Dataloader
-    profiler: Profiler
-    device: torch.device
-    step: int
-
-    def __init__(self, config: DictConfig):
-        job_config = ForgeJobConfig().to_dict()
-        # Hack to deal with literal types from titan
-        job_config = OmegaConf.merge(job_config, config)
-
-        self.current_step = 0
-        self.num_training_steps = job_config.training.steps
-        self.gradient_accumulation_steps = 1  # Example value, adjust as needed
-        self._rank = current_rank().rank
-        self._size = math.prod(current_size().values())
-        self._init_dist()
-        super().__init__(job_config)
-
-    def _init_dist(self):
-        """Initializes torch distributed.
-
-        torchrun normally hands this, but we need to do it ourselves
+@dataclass
+class RLTrainer(ForgeActor):
+    model: Model = field(default_factory=Model)
+    optimizer: Optimizer = field(default_factory=Optimizer)
+    lr_scheduler: LRScheduler = field(default_factory=LRScheduler)
+    training: Training = field(default_factory=Training)
+    parallelism: Parallelism = field(default_factory=Parallelism)
+    checkpoint: Checkpoint = field(default_factory=Checkpoint)
+    activation_checkpoint: ActivationCheckpoint = field(
+        default_factory=ActivationCheckpoint
+    )
+    compile: Compile = field(default_factory=Compile)
+    float8: Float8 = field(default_factory=Float8)
+    comm: Comm = field(default_factory=Comm)
+
+    def __post_init__(self):
+        """Initializes config types and env variables.
+
+        torchrun normally hands env variables, but we need to do it ourselves
         in monarch for now.

-        We should consider putting this into ForgeActor, but having this
-        be explicit for now.
-
         """
+        # Instantiate dict fields
+        for f in fields(self):
+            attr = getattr(self, f.name)
+            if isinstance(attr, Mapping):
+                setattr(self, f.name, f.type(**attr))
+            elif not isinstance(attr, f.type):
+                raise TypeError(
+                    f"{f.name} should be a {f.type} type or a dict like object"
+                )
+
+        self.current_step = 0
+        self.num_training_steps = self.training.steps
+        self.gradient_accumulation_steps = 1
+        self.rank = current_rank().rank
+        self.size = math.prod(current_size().values())
+
         env = {
-            "RANK": str(self._rank),
-            "LOCAL_RANK": str(self._rank),
-            "LOCAL_WORLD_SIZE": str(self._size),
-            "GROUP_RANK": str(self._size),
-            "GROUP_WORLD_SIZE": str(self._size),
-            "ROLE_RANK": str(self._rank),
-            "ROLE_WORLD_SIZE": str(self._size),
+            "RANK": str(self.rank),
+            "LOCAL_RANK": str(self.rank),
+            "LOCAL_WORLD_SIZE": str(self.size),
+            "GROUP_RANK": str(self.size),
+            "GROUP_WORLD_SIZE": str(self.size),
+            "ROLE_RANK": str(self.rank),
+            "ROLE_WORLD_SIZE": str(self.size),
             "ROLE_NAME": "rank",
-            "WORLD_SIZE": str(self._size),
+            "WORLD_SIZE": str(self.size),
             "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
         }
         os.environ.update(env)
-        logger.info("env: {}".format(env))

     @endpoint
     async def setup(self):
-        self.checkpointer.load(step=self.current_step)
-        # self.profiler = self.setup_profiler(self.train_config.profiler_config)
-        # self.logger = self.setup_logger(self.train_config.logger_config)
-        self.optimizers.zero_grad()
-
-        # self.pbar = tqdm(
-        #     initial=0,
-        #     total=self.num_training_steps,
-        #     desc=f"{self.current_step}",
-        # )
-        #
+        # TODO: update ForgeEngine to not use ForgeJobConfig
+        engine_config = {f.name: getattr(self, f.name) for f in fields(self)}
+        self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
+        self.engine.checkpointer.load(step=self.current_step)
+        self.engine.optimizers.zero_grad()

     def forward_backward(
         self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
     ) -> torch.Tensor:
-        model_parts = self.model_parts
-        parallel_dims = self.parallel_dims
+        model_parts = self.engine.model_parts
+        parallel_dims = self.engine.parallel_dims

         # apply context parallelism if cp is enabled
         # ensure CP handles the separate freqs_cis buffer for each pp stage
         inputs = input_dict["tokens"]

-        if getattr(self.model_args, "use_flex_attn", False):
+        if getattr(self.engine.model_args, "use_flex_attn", False):
             cp_mesh = (
                 parallel_dims.world_mesh["cp"] if parallel_dims.cp_enabled else None
             )
-            init_attention_mask(inputs, self.tokenizer.base_tokenizer.eos_id, cp_mesh)
+            init_attention_mask(
+                inputs, self.engine.tokenizer.base_tokenizer.eos_id, cp_mesh
+            )

         optional_context_parallel_ctx = (
             dist_utils.create_context_parallel_ctx(
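
The `__post_init__` hunk above lets each config field arrive either as an already-typed torchtitan config object or as a plain dict-like section (for example, a parsed YAML mapping), and normalizes the latter. A standalone sketch of that coercion pattern, with toy dataclasses standing in for the torchtitan config types:

```python
# Standalone sketch of the Mapping -> dataclass coercion used in __post_init__.
# ToyTraining / ToyJob are illustrative stand-ins, not types from the repo.
from collections.abc import Mapping
from dataclasses import dataclass, field, fields


@dataclass
class ToyTraining:
    steps: int = 5
    seq_len: int = 2048


@dataclass
class ToyJob:
    training: ToyTraining = field(default_factory=ToyTraining)

    def __post_init__(self):
        # Accept either a ToyTraining instance or a dict-like section and
        # normalize it to the declared field type, as RLTrainer does above.
        for f in fields(self):
            attr = getattr(self, f.name)
            if isinstance(attr, Mapping):
                setattr(self, f.name, f.type(**attr))
            elif not isinstance(attr, f.type):
                raise TypeError(
                    f"{f.name} should be a {f.type} type or a dict like object"
                )


job = ToyJob(training={"steps": 10})
assert isinstance(job.training, ToyTraining) and job.training.steps == 10
```
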
@@ -164,11 +152,11 @@ def forward_backward(
             # )
         else:
             # Non-PP forward / backward
-            with self.train_context(optional_context_parallel_ctx):
+            with self.engine.train_context(optional_context_parallel_ctx):
                 assert len(model_parts) == 1
-                with self.maybe_enable_amp:
+                with self.engine.maybe_enable_amp:
                     pred = model_parts[0](inputs)
-                    loss = self.loss_fn(pred, labels)
+                    loss = self.engine.loss_fn(pred, labels)
                 # need to free to before bwd to avoid peaking memory
                 del pred
                 loss.backward()
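
In the non-PP branch above, the forward pass runs under the engine's AMP context and `pred` is freed before `backward()` so the logits are not held through the backward pass. A generic illustration of that pattern in plain PyTorch; the engine's `train_context` and `maybe_enable_amp` are replaced here by a bare `torch.autocast`, and the model, shapes, and loss are toy placeholders:

```python
# Generic illustration of the non-PP forward/backward pattern above:
# run the forward under autocast, compute the loss, then drop the logits
# before backward so they do not contribute to peak memory.
import torch
import torch.nn as nn

model = nn.Linear(16, 4)
inputs = torch.randn(8, 16)
labels = torch.randint(0, 4, (8,))
loss_fn = nn.CrossEntropyLoss()

with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    pred = model(inputs)
    loss = loss_fn(pred, labels)
# free the logits before backward to reduce peak memory
del pred
loss.backward()
```
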
@@ -191,32 +179,92 @@ def train_step(self, batch) -> None:
         # TODO: convert to GRPO Loss
         labels = batch.pop("labels")
         loss = self.forward_backward(batch, labels)
-        # self.pbar.update(1)
-        # self.pbar.set_description(f"{self.current_step}|Loss: {loss}")

-        self.optimizers.step()
-        self.optimizers.zero_grad()
-        self.lr_schedulers.step()
+        self.engine.optimizers.step()
+        self.engine.optimizers.zero_grad()
+        self.engine.lr_schedulers.step()

-        # self.profiler.step()
         self.current_step += 1
-
-        # if self.current_step % self.train_config.val_every_n_steps == 0:
-        #     self.validate()
-        self.checkpointer.save(
+        self.engine.checkpointer.save(
             curr_step=self.current_step,
             last_step=self.current_step == self.num_training_steps,
         )

+    # TODO: integrate the grpo app step with the above step
+    # def train_step(self, self, batch: list(Episode)):
+    #     total_loss = 0.0
+    #     num_groups_processed = 0
+    #
+    #     for episode in batch:
+    #         groups = episode.groups
+    #
+    #         # Collect all response texts and corresponding data
+    #         response_texts = []
+    #         ref_logprobs_list = []
+    #         advantages_list = []
+    #
+    #         for group in groups:
+    #             response_texts.append(group.response)
+    #             ref_logprobs_list.append(group.ref_logprobs)
+    #             advantages_list.append(group.advantage)
+    #
+    #         # Tokenize all responses in batch
+    #         tokenized = self.tokenizer(
+    #             response_texts,
+    #             padding=True,
+    #             truncation=True,
+    #             return_tensors="pt",
+    #             max_length=512,  # Adjust based on your needs
+    #         )
+    #
+    #         input_ids = tokenized["input_ids"].to(self.device)
+    #         attention_mask = tokenized["attention_mask"].to(self.device)
+    #
+    #         # Compute current policy log probabilities using the model
+    #         current_logprobs = compute_sequence_logprobs(
+    #             self.model, input_ids, attention_mask, requires_grad=True
+    #         )
+    #
+    #         # Convert ref_logprobs and advantages to tensors
+    #         ref_logprobs_tensor = torch.stack(ref_logprobs_list).to(self.device)
+    #         advantages_tensor = torch.tensor(advantages_list, dtype=torch.float32).to(
+    #             self.device
+    #         )
+    #
+    #         # Compute GRPO loss components
+    #         # Ratio between current policy and reference policy
+    #         ratio = torch.exp(current_logprobs - ref_logprobs_tensor)
+    #
+    #         # Policy gradient loss weighted by advantages
+    #         pg_loss = -torch.mean(ratio * advantages_tensor)
+    #
+    #         # KL penalty to prevent policy from deviating too far from reference
+    #         kl_penalty = self.beta * torch.mean(
+    #             (current_logprobs - ref_logprobs_tensor) ** 2
+    #         )
+    #
+    #         # Total GRPO loss
+    #         loss = pg_loss + kl_penalty
+    #         total_loss += loss.item()
+    #         num_groups_processed += len(groups)
+    #
+    #         self.optimizer.zero_grad()
+    #         loss.backward()
+    #
+    #         # Gradient clipping (optional but recommended for stability)
+    #         torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+    #
+    #         self.optimizer.step()
+    #
+    #     avg_loss = total_loss / len(batch) if batch else 0.0
+    #
+    #     return {"loss": avg_loss, "groups_processed": num_groups_processed}
+
     @endpoint
     def push_weights(self) -> None:
         pass

     @endpoint
     async def cleanup(self) -> None:
-        # self.pbar.close()
-        if self.checkpointer:
-            self.checkpointer.close()
-
-    def __repr__(self) -> str:
-        return "Trainer"
+        if self.engine.checkpointer:
+            self.engine.checkpointer.close()
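
The commented-out block added above sketches a future GRPO-style `train_step`. Below is a self-contained version of just its loss arithmetic: `compute_sequence_logprobs` is not defined anywhere in this diff, so plain log-probability tensors stand in for its output, `beta` is an assumed KL-penalty coefficient, and all values are toy numbers.

```python
# Toy reproduction of the loss arithmetic in the commented-out GRPO step above.
# The log-prob tensors stand in for compute_sequence_logprobs (not defined in
# this diff); beta is an assumed KL-penalty coefficient.
import torch

beta = 0.1
current_logprobs = torch.tensor([-12.3, -9.8, -15.1], requires_grad=True)
ref_logprobs = torch.tensor([-12.0, -10.2, -14.7])
advantages = torch.tensor([0.5, -0.2, 1.1])

# Ratio between current policy and reference policy, per sequence
ratio = torch.exp(current_logprobs - ref_logprobs)

# Policy-gradient term weighted by advantages
pg_loss = -torch.mean(ratio * advantages)

# Squared-difference penalty keeping the policy near the reference
kl_penalty = beta * torch.mean((current_logprobs - ref_logprobs) ** 2)

loss = pg_loss + kl_penalty
loss.backward()  # gradients flow only through current_logprobs
print(loss.item(), current_logprobs.grad)
```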
