Merged
4 changes: 3 additions & 1 deletion apps/rl/llama3_8b.yaml
@@ -33,9 +33,11 @@ trainer:
seq_len: 2048
max_norm: 1.0
steps: 5
compile: false
dataset: "c4"

compile:
enable: false

parallelism:
data_parallel_replicate_degree: 1
data_parallel_shard_degree: -1
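
The YAML change above replaces the flat `compile: false` flag with a nested `compile:` section carrying an `enable` key, matching the torchtitan `Compile` config group the trainer now consumes. A minimal sketch of how such a nested section reads back through OmegaConf (the snippet invents a trimmed-down config; the real file has the full trainer block):

    from omegaconf import OmegaConf

    # Invented, trimmed-down config mirroring the hunk above.
    cfg = OmegaConf.create(
        """
        trainer:
          steps: 5
          compile:
            enable: false
        """
    )
    assert cfg.trainer.compile.enable is False  # nested flag instead of a bare `compile: false`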
2 changes: 1 addition & 1 deletion apps/rl/main.py
@@ -30,7 +30,7 @@ async def run(cfg: DictConfig):
spawn_actors(
name="trainer",
actor_cls=RLTrainer,
cfg={"config": cfg.trainer},
cfg=cfg.trainer,
processes=cfg.trainer.pop("processes"),
set_address=True,
),
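
This call-site change follows from the trainer rewrite below: `RLTrainer` is now a dataclass whose fields correspond to the keys of the `trainer` section, so the section is passed through directly instead of being wrapped under a `config` key. A rough sketch of the shape involved, using an invented two-key section as a stand-in for the real config:

    from omegaconf import OmegaConf

    # Invented stand-in for cfg.trainer; the real section holds model, optimizer, training, etc.
    cfg = OmegaConf.create({"trainer": {"processes": 8, "training": {"steps": 5}}})
    processes = cfg.trainer.pop("processes")  # consumed separately by spawn_actors
    trainer_cfg = cfg.trainer                 # handed to the actor as-is, no {"config": ...} wrapper
    assert "processes" not in trainer_cfg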
1 change: 1 addition & 0 deletions pyproject.toml
@@ -22,6 +22,7 @@ dependencies = [
"tokenizers",
# Miscellaneous
"omegaconf",
"wandb",
]
dynamic = ["version"]
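
The new `wandb` dependency isn't exercised anywhere in this diff, so the intended wiring (presumably metric logging) isn't shown; for orientation only, the standard minimal usage of the library looks roughly like this, with a made-up project name:

    import wandb

    run = wandb.init(project="forge-rl", mode="offline")  # hypothetical project name
    wandb.log({"loss": 0.0, "step": 0})
    run.finish()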

250 changes: 149 additions & 101 deletions src/forge/actors/trainer.py
@@ -8,123 +8,111 @@
import logging
import math
import os
from typing import Any
from collections.abc import Mapping
from dataclasses import dataclass, field, fields

import torch
import torchtitan.experiments.forge.train_spec as forge_train_spec
from monarch.actor import current_rank, current_size, endpoint
from omegaconf import DictConfig, OmegaConf
from torch import nn
from torchtitan.components.loss import LossFunction

# from torchdata.stateful_dataloader import StatefulDataLoader
# from torchtitan.components.checkpoint import ModelWrapper
from torchtitan.components.lr_scheduler import LRSchedulersContainer
from torchtitan.components.optimizer import OptimizersContainer
from torchtitan.distributed import ParallelDims, utils as dist_utils
from torchtitan.config.job_config import (
ActivationCheckpoint,
Checkpoint,
Comm,
Compile,
Float8,
LRScheduler,
Model,
Optimizer,
Parallelism,
Training,
)

from torchtitan.distributed import utils as dist_utils
from torchtitan.experiments.forge.engine import ForgeEngine
from torchtitan.experiments.forge.job_config import ForgeJobConfig

# from tqdm import tqdm

from forge.controller import ForgeActor

# from forge.interfaces import RLLoss

# stubs for now
Checkpointer = Any
Dataloader = Any
MetricLogger = Any
Profiler = Any
Tokenizer = Any

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class RLTrainer(ForgeActor, ForgeEngine):
job_config: ForgeJobConfig
train_spec: forge_train_spec.ForgeTrainSpec
parallel_dims: ParallelDims
model: list[nn.Module]
loss_fn: LossFunction
optimizer: OptimizersContainer
lr_scheduler: LRSchedulersContainer
checkpointer: Checkpointer
tokenizer: Tokenizer
train_dataloader: Dataloader
# val_dataloader: Dataloader
profiler: Profiler
device: torch.device
step: int

def __init__(self, config: DictConfig):
job_config = ForgeJobConfig().to_dict()
# Hack to deal with literal types from titan
job_config = OmegaConf.merge(job_config, config)

self.current_step = 0
self.num_training_steps = job_config.training.steps
self.gradient_accumulation_steps = 1 # Example value, adjust as needed
self._rank = current_rank().rank
self._size = math.prod(current_size().values())
self._init_dist()
super().__init__(job_config)

def _init_dist(self):
"""Initializes torch distributed.

torchrun normally handles this, but we need to do it ourselves
@dataclass
class RLTrainer(ForgeActor):
model: Model = field(default_factory=Model)
optimizer: Optimizer = field(default_factory=Optimizer)
lr_scheduler: LRScheduler = field(default_factory=LRScheduler)
training: Training = field(default_factory=Training)
parallelism: Parallelism = field(default_factory=Parallelism)
checkpoint: Checkpoint = field(default_factory=Checkpoint)
activation_checkpoint: ActivationCheckpoint = field(
default_factory=ActivationCheckpoint
)
compile: Compile = field(default_factory=Compile)
float8: Float8 = field(default_factory=Float8)
comm: Comm = field(default_factory=Comm)

def __post_init__(self):
"""Initializes config types and env variables.

torchrun normally handles env variables, but we need to do it ourselves
in monarch for now.

We should consider putting this into ForgeActor, but having this
be explicit for now.

"""
# Instantiate dict fields
for f in fields(self):
attr = getattr(self, f.name)
if isinstance(attr, Mapping):
setattr(self, f.name, f.type(**attr))
elif not isinstance(attr, f.type):
raise TypeError(
f"{f.name} should be a {f.type} type or a dict like object"
)

self.current_step = 0
self.num_training_steps = self.training.steps
self.gradient_accumulation_steps = 1
self.rank = current_rank().rank
self.size = math.prod(current_size().values())

env = {
"RANK": str(self._rank),
"LOCAL_RANK": str(self._rank),
"LOCAL_WORLD_SIZE": str(self._size),
"GROUP_RANK": str(self._size),
"GROUP_WORLD_SIZE": str(self._size),
"ROLE_RANK": str(self._rank),
"ROLE_WORLD_SIZE": str(self._size),
"RANK": str(self.rank),
"LOCAL_RANK": str(self.rank),
"LOCAL_WORLD_SIZE": str(self.size),
"GROUP_RANK": str(self.size),
"GROUP_WORLD_SIZE": str(self.size),
"ROLE_RANK": str(self.rank),
"ROLE_WORLD_SIZE": str(self.size),
"ROLE_NAME": "rank",
"WORLD_SIZE": str(self._size),
"WORLD_SIZE": str(self.size),
"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
}
os.environ.update(env)
logger.info("env: {}".format(env))

@endpoint
async def setup(self):
self.checkpointer.load(step=self.current_step)
# self.profiler = self.setup_profiler(self.train_config.profiler_config)
# self.logger = self.setup_logger(self.train_config.logger_config)
self.optimizers.zero_grad()

# self.pbar = tqdm(
# initial=0,
# total=self.num_training_steps,
# desc=f"{self.current_step}",
# )
#
# TODO: update ForgeEngine to not use ForgeJobConfig
engine_config = {f.name: getattr(self, f.name) for f in fields(self)}
self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
self.engine.checkpointer.load(step=self.current_step)
self.engine.optimizers.zero_grad()

def forward_backward(
self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
) -> torch.Tensor:
model_parts = self.model_parts
parallel_dims = self.parallel_dims
model_parts = self.engine.model_parts
parallel_dims = self.engine.parallel_dims

# apply context parallelism if cp is enabled
# ensure CP handles the separate freqs_cis buffer for each pp stage
inputs = input_dict["tokens"]

if getattr(self.model_args, "use_flex_attn", False):
if getattr(self.engine.model_args, "use_flex_attn", False):
cp_mesh = (
parallel_dims.world_mesh["cp"] if parallel_dims.cp_enabled else None
)
init_attention_mask(inputs, self.tokenizer.base_tokenizer.eos_id, cp_mesh)
init_attention_mask(
inputs, self.engine.tokenizer.base_tokenizer.eos_id, cp_mesh
)

optional_context_parallel_ctx = (
dist_utils.create_context_parallel_ctx(
@@ -164,11 +152,11 @@ def forward_backward(
# )
else:
# Non-PP forward / backward
with self.train_context(optional_context_parallel_ctx):
with self.engine.train_context(optional_context_parallel_ctx):
assert len(model_parts) == 1
with self.maybe_enable_amp:
with self.engine.maybe_enable_amp:
pred = model_parts[0](inputs)
loss = self.loss_fn(pred, labels)
loss = self.engine.loss_fn(pred, labels)
# need to free before bwd to avoid peaking memory
del pred
loss.backward()
@@ -191,32 +179,92 @@ def train_step(self, batch) -> None:
# TODO: convert to GRPO Loss
labels = batch.pop("labels")
loss = self.forward_backward(batch, labels)
# self.pbar.update(1)
# self.pbar.set_description(f"{self.current_step}|Loss: {loss}")

self.optimizers.step()
self.optimizers.zero_grad()
self.lr_schedulers.step()
self.engine.optimizers.step()
self.engine.optimizers.zero_grad()
self.engine.lr_schedulers.step()

# self.profiler.step()
self.current_step += 1

# if self.current_step % self.train_config.val_every_n_steps == 0:
# self.validate()
self.checkpointer.save(
self.engine.checkpointer.save(
curr_step=self.current_step,
last_step=self.current_step == self.num_training_steps,
)

# TODO: integrate the grpo app step with the above step
# def train_step(self, batch: list[Episode]):
# total_loss = 0.0
# num_groups_processed = 0
#
# for episode in batch:
# groups = episode.groups
#
# # Collect all response texts and corresponding data
# response_texts = []
# ref_logprobs_list = []
# advantages_list = []
#
# for group in groups:
# response_texts.append(group.response)
# ref_logprobs_list.append(group.ref_logprobs)
# advantages_list.append(group.advantage)
#
# # Tokenize all responses in batch
# tokenized = self.tokenizer(
# response_texts,
# padding=True,
# truncation=True,
# return_tensors="pt",
# max_length=512, # Adjust based on your needs
# )
#
# input_ids = tokenized["input_ids"].to(self.device)
# attention_mask = tokenized["attention_mask"].to(self.device)
#
# # Compute current policy log probabilities using the model
# current_logprobs = compute_sequence_logprobs(
# self.model, input_ids, attention_mask, requires_grad=True
# )
#
# # Convert ref_logprobs and advantages to tensors
# ref_logprobs_tensor = torch.stack(ref_logprobs_list).to(self.device)
# advantages_tensor = torch.tensor(advantages_list, dtype=torch.float32).to(
# self.device
# )
#
# # Compute GRPO loss components
# # Ratio between current policy and reference policy
# ratio = torch.exp(current_logprobs - ref_logprobs_tensor)
#
# # Policy gradient loss weighted by advantages
# pg_loss = -torch.mean(ratio * advantages_tensor)
#
# # KL penalty to prevent policy from deviating too far from reference
# kl_penalty = self.beta * torch.mean(
# (current_logprobs - ref_logprobs_tensor) ** 2
# )
#
# # Total GRPO loss
# loss = pg_loss + kl_penalty
# total_loss += loss.item()
# num_groups_processed += len(groups)
#
# self.optimizer.zero_grad()
# loss.backward()
#
# # Gradient clipping (optional but recommended for stability)
# torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
#
# self.optimizer.step()
#
# avg_loss = total_loss / len(batch) if batch else 0.0
#
# return {"loss": avg_loss, "groups_processed": num_groups_processed}

@endpoint
def push_weights(self) -> None:
pass

@endpoint
async def cleanup(self) -> None:
# self.pbar.close()
if self.checkpointer:
self.checkpointer.close()

def __repr__(self) -> str:
return "Trainer"
if self.engine.checkpointer:
self.engine.checkpointer.close()
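
The core change in trainer.py is the move from a `DictConfig`-driven `__init__` to a plain dataclass whose `__post_init__` coerces dict-like sections into the torchtitan config types and then builds a `ForgeEngine` lazily in `setup()`. A self-contained miniature of the coercion pattern, using an invented `MiniTrainer` and a stand-in `Training` class rather than the real torchtitan ones:

    from collections.abc import Mapping
    from dataclasses import dataclass, field, fields

    @dataclass
    class Training:                  # stand-in for torchtitan's Training config
        steps: int = 5

    @dataclass
    class MiniTrainer:               # invented miniature of the RLTrainer pattern above
        training: Training = field(default_factory=Training)

        def __post_init__(self):
            # Coerce dict-like sections (e.g. OmegaConf mappings) into their config classes.
            for f in fields(self):
                attr = getattr(self, f.name)
                if isinstance(attr, Mapping):
                    setattr(self, f.name, f.type(**attr))
                elif not isinstance(attr, f.type):
                    raise TypeError(f"{f.name} should be a {f.type} or a dict-like object")

    t = MiniTrainer(training={"steps": 3})
    assert isinstance(t.training, Training) and t.training.steps == 3

One caveat with this pattern: it depends on `dataclasses.fields()` exposing real classes in `f.type`, so it would break under postponed annotations (`from __future__ import annotations`), which turn `f.type` into a string.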