diff --git a/apps/rl/llama3_8b.yaml b/apps/rl/llama3_8b.yaml
index 58111ebb8..49ffe3e47 100644
--- a/apps/rl/llama3_8b.yaml
+++ b/apps/rl/llama3_8b.yaml
@@ -33,9 +33,11 @@ trainer:
     seq_len: 2048
     max_norm: 1.0
     steps: 5
-    compile: false
     dataset: "c4"

+  compile:
+    enable: false
+
   parallelism:
     data_parallel_replicate_degree: 1
     data_parallel_shard_degree: -1
diff --git a/apps/rl/main.py b/apps/rl/main.py
index c16bdb04c..3ff8cea37 100644
--- a/apps/rl/main.py
+++ b/apps/rl/main.py
@@ -30,7 +30,7 @@ async def run(cfg: DictConfig):
         spawn_actors(
             name="trainer",
             actor_cls=RLTrainer,
-            cfg={"config": cfg.trainer},
+            cfg=cfg.trainer,
             processes=cfg.trainer.pop("processes"),
             set_address=True,
         ),
diff --git a/pyproject.toml b/pyproject.toml
index 3113812a5..ccb0cb12c 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -22,6 +22,7 @@ dependencies = [
     "tokenizers",
     # Miscellaneous
     "omegaconf",
+    "wandb",
 ]

 dynamic = ["version"]
diff --git a/src/forge/actors/trainer.py b/src/forge/actors/trainer.py
index 30ed1b69d..4232ca5ca 100644
--- a/src/forge/actors/trainer.py
+++ b/src/forge/actors/trainer.py
@@ -8,123 +8,111 @@
 import logging
 import math
 import os
-from typing import Any
+from collections.abc import Mapping
+from dataclasses import dataclass, field, fields

 import torch
-import torchtitan.experiments.forge.train_spec as forge_train_spec
 from monarch.actor import current_rank, current_size, endpoint
-from omegaconf import DictConfig, OmegaConf
-from torch import nn
-from torchtitan.components.loss import LossFunction
-
-# from torchdata.stateful_dataloader import StatefulDataLoader
-# from torchtitan.components.checkpoint import ModelWrapper
-from torchtitan.components.lr_scheduler import LRSchedulersContainer
-from torchtitan.components.optimizer import OptimizersContainer
-from torchtitan.distributed import ParallelDims, utils as dist_utils
+from torchtitan.config.job_config import (
+    ActivationCheckpoint,
+    Checkpoint,
+    Comm,
+    Compile,
+    Float8,
+    LRScheduler,
+    Model,
+    Optimizer,
+    Parallelism,
+    Training,
+)
+
+from torchtitan.distributed import utils as dist_utils
 from torchtitan.experiments.forge.engine import ForgeEngine
 from torchtitan.experiments.forge.job_config import ForgeJobConfig

-# from tqdm import tqdm
-
 from forge.controller import ForgeActor

-# from forge.interfaces import RLLoss
-
-# stubs for now
-Checkpointer = Any
-Dataloader = Any
-MetricLogger = Any
-Profiler = Any
-Tokenizer = Any
-
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)


-class RLTrainer(ForgeActor, ForgeEngine):
-    job_config: ForgeJobConfig
-    train_spec: forge_train_spec.ForgeTrainSpec
-    parallel_dims: ParallelDims
-    model: list[nn.Module]
-    loss_fn: LossFunction
-    optimizer: OptimizersContainer
-    lr_scheduler: LRSchedulersContainer
-    checkpointer: Checkpointer
-    tokenizer: Tokenizer
-    train_dataloader: Dataloader
-    # val_dataloader: Dataloader
-    profiler: Profiler
-    device: torch.device
-    step: int
-
-    def __init__(self, config: DictConfig):
-        job_config = ForgeJobConfig().to_dict()
-        # Hack to deal with literal types from titan
-        job_config = OmegaConf.merge(job_config, config)
-
-        self.current_step = 0
-        self.num_training_steps = job_config.training.steps
-        self.gradient_accumulation_steps = 1  # Example value, adjust as needed
-        self._rank = current_rank().rank
-        self._size = math.prod(current_size().values())
-        self._init_dist()
-        super().__init__(job_config)
-
-    def _init_dist(self):
-        """Initializes torch distributed.
-
-        torchrun normally hands this, but we need to do it ourselves
+@dataclass
+class RLTrainer(ForgeActor):
+    model: Model = field(default_factory=Model)
+    optimizer: Optimizer = field(default_factory=Optimizer)
+    lr_scheduler: LRScheduler = field(default_factory=LRScheduler)
+    training: Training = field(default_factory=Training)
+    parallelism: Parallelism = field(default_factory=Parallelism)
+    checkpoint: Checkpoint = field(default_factory=Checkpoint)
+    activation_checkpoint: ActivationCheckpoint = field(
+        default_factory=ActivationCheckpoint
+    )
+    compile: Compile = field(default_factory=Compile)
+    float8: Float8 = field(default_factory=Float8)
+    comm: Comm = field(default_factory=Comm)
+
+    def __post_init__(self):
+        """Initializes config types and env variables.
+
+        torchrun normally handles env variables, but we need to do it ourselves
         in monarch for now.
-
-        We should consider putting this into ForgeActor, but having this
-        be explicit for now.
         """
+        # Instantiate dict fields
+        for f in fields(self):
+            attr = getattr(self, f.name)
+            if isinstance(attr, Mapping):
+                setattr(self, f.name, f.type(**attr))
+            elif not isinstance(attr, f.type):
+                raise TypeError(
+                    f"{f.name} should be a {f.type} type or a dict like object"
+                )
+
+        self.current_step = 0
+        self.num_training_steps = self.training.steps
+        self.gradient_accumulation_steps = 1
+        self.rank = current_rank().rank
+        self.size = math.prod(current_size().values())
+
         env = {
-            "RANK": str(self._rank),
-            "LOCAL_RANK": str(self._rank),
-            "LOCAL_WORLD_SIZE": str(self._size),
-            "GROUP_RANK": str(self._size),
-            "GROUP_WORLD_SIZE": str(self._size),
-            "ROLE_RANK": str(self._rank),
-            "ROLE_WORLD_SIZE": str(self._size),
+            "RANK": str(self.rank),
+            "LOCAL_RANK": str(self.rank),
+            "LOCAL_WORLD_SIZE": str(self.size),
+            "GROUP_RANK": str(self.size),
+            "GROUP_WORLD_SIZE": str(self.size),
+            "ROLE_RANK": str(self.rank),
+            "ROLE_WORLD_SIZE": str(self.size),
             "ROLE_NAME": "rank",
-            "WORLD_SIZE": str(self._size),
+            "WORLD_SIZE": str(self.size),
             "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
         }
         os.environ.update(env)
-        logger.info("env: {}".format(env))

     @endpoint
     async def setup(self):
-        self.checkpointer.load(step=self.current_step)
-        # self.profiler = self.setup_profiler(self.train_config.profiler_config)
-        # self.logger = self.setup_logger(self.train_config.logger_config)
-        self.optimizers.zero_grad()
-
-        # self.pbar = tqdm(
-        #     initial=0,
-        #     total=self.num_training_steps,
-        #     desc=f"{self.current_step}",
-        # )
-        #
+        # TODO: update ForgeEngine to not use ForgeJobConfig
+        engine_config = {f.name: getattr(self, f.name) for f in fields(self)}
+        self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
+        self.engine.checkpointer.load(step=self.current_step)
+        self.engine.optimizers.zero_grad()

     def forward_backward(
         self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
     ) -> torch.Tensor:
-        model_parts = self.model_parts
-        parallel_dims = self.parallel_dims
+        model_parts = self.engine.model_parts
+        parallel_dims = self.engine.parallel_dims

         # apply context parallelism if cp is enabled
         # ensure CP handles the separate freqs_cis buffer for each pp stage
         inputs = input_dict["tokens"]
-        if getattr(self.model_args, "use_flex_attn", False):
+        if getattr(self.engine.model_args, "use_flex_attn", False):
             cp_mesh = (
                 parallel_dims.world_mesh["cp"] if parallel_dims.cp_enabled else None
             )
-            init_attention_mask(inputs, self.tokenizer.base_tokenizer.eos_id, cp_mesh)
+            init_attention_mask(
+                inputs, self.engine.tokenizer.base_tokenizer.eos_id, cp_mesh
+            )
 
         optional_context_parallel_ctx = (
             dist_utils.create_context_parallel_ctx(
@@ -164,11 +152,11 @@ def forward_backward(
             # )
         else:
             # Non-PP forward / backward
-            with self.train_context(optional_context_parallel_ctx):
+            with self.engine.train_context(optional_context_parallel_ctx):
                 assert len(model_parts) == 1
-                with self.maybe_enable_amp:
+                with self.engine.maybe_enable_amp:
                     pred = model_parts[0](inputs)
-                    loss = self.loss_fn(pred, labels)
+                    loss = self.engine.loss_fn(pred, labels)
                     # need to free to before bwd to avoid peaking memory
                     del pred
                     loss.backward()
@@ -191,32 +179,92 @@ def train_step(self, batch) -> None:
         # TODO: convert to GRPO Loss
         labels = batch.pop("labels")
         loss = self.forward_backward(batch, labels)
-        # self.pbar.update(1)
-        # self.pbar.set_description(f"{self.current_step}|Loss: {loss}")

-        self.optimizers.step()
-        self.optimizers.zero_grad()
-        self.lr_schedulers.step()
+        self.engine.optimizers.step()
+        self.engine.optimizers.zero_grad()
+        self.engine.lr_schedulers.step()

-        # self.profiler.step()
         self.current_step += 1
-
-        # if self.current_step % self.train_config.val_every_n_steps == 0:
-        #     self.validate()
-        self.checkpointer.save(
+        self.engine.checkpointer.save(
             curr_step=self.current_step,
             last_step=self.current_step == self.num_training_steps,
         )

+    # TODO: integrate the grpo app step with the above step
+    # def train_step(self, batch: list[Episode]):
+    #     total_loss = 0.0
+    #     num_groups_processed = 0
+    #
+    #     for episode in batch:
+    #         groups = episode.groups
+    #
+    #         # Collect all response texts and corresponding data
+    #         response_texts = []
+    #         ref_logprobs_list = []
+    #         advantages_list = []
+    #
+    #         for group in groups:
+    #             response_texts.append(group.response)
+    #             ref_logprobs_list.append(group.ref_logprobs)
+    #             advantages_list.append(group.advantage)
+    #
+    #         # Tokenize all responses in batch
+    #         tokenized = self.tokenizer(
+    #             response_texts,
+    #             padding=True,
+    #             truncation=True,
+    #             return_tensors="pt",
+    #             max_length=512,  # Adjust based on your needs
+    #         )
+    #
+    #         input_ids = tokenized["input_ids"].to(self.device)
+    #         attention_mask = tokenized["attention_mask"].to(self.device)
+    #
+    #         # Compute current policy log probabilities using the model
+    #         current_logprobs = compute_sequence_logprobs(
+    #             self.model, input_ids, attention_mask, requires_grad=True
+    #         )
+    #
+    #         # Convert ref_logprobs and advantages to tensors
+    #         ref_logprobs_tensor = torch.stack(ref_logprobs_list).to(self.device)
+    #         advantages_tensor = torch.tensor(advantages_list, dtype=torch.float32).to(
+    #             self.device
+    #         )
+    #
+    #         # Compute GRPO loss components
+    #         # Ratio between current policy and reference policy
+    #         ratio = torch.exp(current_logprobs - ref_logprobs_tensor)
+    #
+    #         # Policy gradient loss weighted by advantages
+    #         pg_loss = -torch.mean(ratio * advantages_tensor)
+    #
+    #         # KL penalty to prevent policy from deviating too far from reference
+    #         kl_penalty = self.beta * torch.mean(
+    #             (current_logprobs - ref_logprobs_tensor) ** 2
+    #         )
+    #
+    #         # Total GRPO loss
+    #         loss = pg_loss + kl_penalty
+    #         total_loss += loss.item()
+    #         num_groups_processed += len(groups)
+    #
+    #         self.optimizer.zero_grad()
+    #         loss.backward()
+    #
+    #         # Gradient clipping (optional but recommended for stability)
+    #         torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+    #
+    #         self.optimizer.step()
+    #
+    #     avg_loss = total_loss / len(batch) if batch else 0.0
+    #
+    #     return {"loss": avg_loss, "groups_processed": num_groups_processed}
+
     @endpoint
     def push_weights(self) -> None:
         pass

     @endpoint
     async def cleanup(self) -> None:
-        # self.pbar.close()
-        if self.checkpointer:
-            self.checkpointer.close()
-
-    def __repr__(self) -> str:
-        return "Trainer"
+        if self.engine.checkpointer:
+            self.engine.checkpointer.close()
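
For reference, a minimal, self-contained sketch of the dict-coercion pattern introduced in RLTrainer.__post_init__ above, using stand-in dataclasses rather than the real torchtitan config types:

# Illustrative sketch only (not part of the diff). Training/Compile below are
# stand-ins for the torchtitan.config.job_config sections used by RLTrainer.
from collections.abc import Mapping
from dataclasses import dataclass, field, fields


@dataclass
class Training:  # stand-in section
    steps: int = 10
    seq_len: int = 2048


@dataclass
class Compile:  # stand-in section
    enable: bool = False


@dataclass
class Trainer:
    training: Training = field(default_factory=Training)
    compile: Compile = field(default_factory=Compile)

    def __post_init__(self):
        # Coerce dict-like sections (e.g. parsed from YAML) into their dataclass
        # field types, mirroring RLTrainer.__post_init__ in the diff above.
        for f in fields(self):
            attr = getattr(self, f.name)
            if isinstance(attr, Mapping):
                setattr(self, f.name, f.type(**attr))
            elif not isinstance(attr, f.type):
                raise TypeError(f"{f.name} should be a {f.type} type or a dict-like object")


# Mirrors the YAML change above, where `compile` is now a nested section with `enable`.
t = Trainer(training={"steps": 5, "seq_len": 2048}, compile={"enable": False})
assert isinstance(t.compile, Compile) and t.compile.enable is False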