-
Notifications
You must be signed in to change notification settings - Fork 16
RLTrainer #40
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
RLTrainer #40
Changes from all commits
a06b0ec
d1eab28
8e5c3d3
ba424b7
9e41f2b
33a5cb3
f40d9bb
49e5996
2b35e29
3678d49
f5c9fa0
54d776c
cd83ac0
1105456
28878c6
3ec6f94
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,138 @@ | ||
# Config for GRPO finetuning using a Llama3.1 8B Instruct model | ||
# | ||
# This config assumes that you've run the following command before launching | ||
# this run: | ||
# export HF_HUB_DISABLE_XET=1 | ||
# uv run forge download meta-llama/Meta-Llama-3.1-8B-Instruct | ||
|
||
|
||
trainer: | ||
comm: | ||
trace_buf_size: 0 | ||
|
||
model: | ||
name: llama3 | ||
flavor: 8B | ||
tokenizer_path: /tmp/Meta-Llama-3.1-8B-Instruct | ||
|
||
processes: | ||
scheduler: local # local | mast (not supported yet) | ||
num_hosts: 1 | ||
num_procs: 4 | ||
|
||
optimizer: | ||
name: AdamW | ||
lr: 1e-5 | ||
eps: 1e-8 | ||
|
||
lr_scheduler: | ||
warmup_steps: 1 | ||
|
||
training: | ||
local_batch_size: 1 | ||
seq_len: 2048 | ||
max_norm: 1.0 | ||
steps: 5 | ||
compile: false | ||
dataset: "c4" | ||
|
||
parallelism: | ||
data_parallel_replicate_degree: 1 | ||
data_parallel_shard_degree: -1 | ||
tensor_parallel_degree: 1 | ||
pipeline_parallel_degree: 1 | ||
context_parallel_degree: 1 | ||
expert_parallel_degree: 1 | ||
disable_loss_parallel: false | ||
|
||
checkpoint: | ||
enable_checkpoint: true | ||
folder: /tmp/Meta-Llama-3.1-8B-Instruct/saved_checkpoints | ||
initial_load_path: /tmp/Meta-Llama-3.1-8B-Instruct/ | ||
initial_load_in_hf: true | ||
last_save_in_hf: true | ||
interval: 500 | ||
async_mode: "disabled" | ||
|
||
activation_checkpoint: | ||
mode: selective | ||
selective_ac_option: op | ||
|
||
replay_buffer:
  batch_size: 2
  max_policy_age: 2
  # A bare `None` is loaded by YAML as the string "None", not as a null value.
  # Use `null` so the replay buffer receives an actual "no seed" value.
  seed: null
  processes:
    scheduler: local # local | mast (not supported yet)
    num_hosts: 1
    num_procs: 1
|
||
# policy: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Remove from here on down? |
||
# scheduler: | ||
# scheduler: local # local | mast (not supported yet) | ||
# num_hosts: 1 | ||
# num_gpus: 1 | ||
# oncall: torchtune | ||
# identity: pytorch_distributed | ||
# image: forge_workspace:latest | ||
# | ||
# model: "meta-llama/Llama-3.1-8B-Instruct" | ||
# tensor_parallel_size: 2 | ||
# pipeline_parallel_size: 1 | ||
# enforce_eager: false | ||
|
||
# postprocessor: | ||
# scheduler: | ||
# scheduler: local # local | mast (not supported yet) | ||
# num_hosts: 1 | ||
# num_gpus: 1 | ||
# oncall: torchtune | ||
# identity: pytorch_distributed | ||
# image: forge_workspace:latest | ||
# | ||
# comm: | ||
# trace_buf_size: 0 | ||
# | ||
# optimizer: | ||
# name: AdamW | ||
# lr: 1e-5 | ||
# eps: 1e-8 | ||
# | ||
# lr_scheduler: | ||
# warmup_steps: 1 | ||
# | ||
# training: | ||
# local_batch_size: 1 | ||
# seq_len: 2048 | ||
# max_norm: 1.0 | ||
# steps: 5 | ||
# compile: false | ||
# dataset: "c4" | ||
# | ||
# parallelism: | ||
# data_parallel_replicate_degree: 1 | ||
# data_parallel_shard_degree: -1 | ||
# tensor_parallel_degree: 1 | ||
# pipeline_parallel_degree: 1 | ||
# context_parallel_degree: 1 | ||
# expert_parallel_degree: 1 | ||
# disable_loss_parallel: false | ||
# | ||
# checkpoint: | ||
# enable_checkpoint: true | ||
# folder: /tmp/Meta-Llama-3.1-8B-Instruct/ | ||
# interval: 500 | ||
# async_mode: "disabled" | ||
# | ||
# activation_checkpoint: | ||
# mode: selective | ||
# selective_ac_option: op | ||
# | ||
|
||
# profiling: | ||
# enable_profiling: false | ||
|
||
# metrics: | ||
# log_freq: 10 | ||
# enable_tensorboard: true | ||
# save_tb_folder: "tb" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates. | ||
# All rights reserved. | ||
# | ||
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
"""A working example showcasing a practical example of forge with RL. | ||
Run this with: | ||
python -m apps.rl.main --config apps/rl/llama3_8b.yaml | ||
""" | ||
|
||
import asyncio | ||
import logging | ||
import sys | ||
|
||
from forge.actors import ReplayBuffer, RLTrainer | ||
|
||
from forge.cli.config import parse | ||
from forge.controller import spawn_actors | ||
from omegaconf import DictConfig | ||
|
||
|
||
logger = logging.getLogger(__name__) | ||
logger.setLevel(logging.INFO) | ||
|
||
|
||
async def run(cfg: DictConfig):
    """Spawn the trainer and replay-buffer actors, set them up, then shut down.

    Args:
        cfg: Parsed app config. ``cfg.trainer`` and ``cfg.replay_buffer`` must
            each contain a ``processes`` entry; it is popped from the section
            (mutating ``cfg``) and passed to ``spawn_actors`` as the process
            layout. The trainer section is passed as the ``config`` argument
            and the replay-buffer section is passed as the actor config.
    """
    # NOTE(review): spawn_actors is used instead of spawn_services because this
    # app was written before spawn_services existed; the next step is GRPO
    # integration, so it was not worth migrating here.
    trainer, buffer = await asyncio.gather(
        spawn_actors(
            name="trainer",
            actor_cls=RLTrainer,
            cfg={"config": cfg.trainer},
            processes=cfg.trainer.pop("processes"),
            set_address=True,
        ),
        spawn_actors(
            name="replay_buffer",
            actor_cls=ReplayBuffer,
            cfg=cfg.replay_buffer,
            processes=cfg.replay_buffer.pop("processes"),
        ),
    )
    print("Actors spawned")

    try:
        # Initialize everything
        await asyncio.gather(
            buffer.setup.call(),
            trainer.setup.call(),
        )
        print("Setup done")
    finally:
        # Stop every mesh that was spawned (previously only the trainer's was
        # stopped), and do it even when setup raised, so that no actor
        # processes are left running after this coroutine exits.
        print("shutting down...")
        await asyncio.gather(*(actor.mesh.stop() for actor in (trainer, buffer)))
|
||
|
||
@parse
def recipe_main(cfg: DictConfig) -> None:
    """Synchronous entry point that drives the async ``run`` to completion.

    Args:
        cfg: Config for the app. Presumably supplied by the ``parse``
            decorator from the ``--config`` command-line argument
            (TODO confirm against ``forge.cli.config``).
    """
    app_coroutine = run(cfg)
    asyncio.run(app_coroutine)
Comment on lines
+57
to
+59
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Is the indirection to There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We should probably not have run run but I left it this way for the reason you said |
||
|
||
|
||
if __name__ == "__main__": | ||
sys.exit(recipe_main()) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,7 +16,6 @@ | |
import math | ||
import os | ||
import sys | ||
from dataclasses import asdict | ||
from functools import partial | ||
from typing import Any | ||
|
||
|
@@ -71,7 +70,11 @@ class ForgeSFTRecipe(ForgeActor, ForgeEngine): | |
device: torch.device | ||
step: int | ||
|
||
def __init__(self, job_config: ForgeJobConfig): | ||
def __init__(self, config: DictConfig): | ||
job_config = ForgeJobConfig().to_dict() | ||
# Hack to deal with literal types from titan | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. nit but the hack is on the to_dict(), right? (Also fine to just remove the comment completely) |
||
job_config = OmegaConf.merge(job_config, config) | ||
|
||
self.current_step = 0 | ||
self.num_training_steps = job_config.training.steps | ||
self.metric_logger = None # TODO: fix this | ||
|
@@ -92,8 +95,6 @@ def _init_dist(self): | |
""" | ||
env = { | ||
"MASTER_ADDR": "localhost", | ||
"MASTER_PORT": "12345", | ||
"RANK": str(self._rank), | ||
"LOCAL_RANK": str(self._rank), | ||
"LOCAL_WORLD_SIZE": str(self._size), | ||
|
@@ -103,7 +104,6 @@ def _init_dist(self): | |
"ROLE_WORLD_SIZE": str(self._size), | ||
"ROLE_NAME": "rank", | ||
"WORLD_SIZE": str(self._size), | ||
"CUDA_VISIBLE_DEVICES": str(self._rank), | ||
"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True", | ||
} | ||
os.environ.update(env) | ||
|
@@ -280,7 +280,13 @@ def __repr__(self) -> str: | |
async def run(cfg: DictConfig) -> None: | ||
logging.info("Spawing recipe...") | ||
process_cfg = cfg.pop("processes") | ||
recipe = await spawn_actors("sft", ForgeSFTRecipe, cfg, process_cfg) | ||
recipe = await spawn_actors( | ||
"sft", | ||
ForgeSFTRecipe, | ||
{"config": cfg}, | ||
process_cfg, | ||
set_address=True, | ||
) | ||
|
||
logging.info("Created recipe, running setup.") | ||
await recipe.setup.call() | ||
|
@@ -296,11 +302,6 @@ async def run(cfg: DictConfig) -> None: | |
|
||
@parse | ||
def recipe_main(cfg: DictConfig) -> None: | ||
# TODO: this is a hack to get the defaults from ForgeJobConfig | ||
default_cfg = ForgeJobConfig() | ||
# Hack to deal with literal types from titan | ||
default_cfg = asdict(default_cfg) | ||
cfg = OmegaConf.merge(default_cfg, cfg) | ||
asyncio.run(run(cfg)) | ||
|
||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,25 @@ | |
# This source code is licensed under the BSD-style license found in the | ||
# LICENSE file in the root directory of this source tree. | ||
|
||
from .collector import Collector | ||
__all__ = ["Policy", "PolicyRouter", "RLTrainer", "ReplayBuffer"] | ||
|
||
__all__ = ["Collector"] | ||
|
||
def __getattr__(name):
    """Lazily resolve the package's public actor classes on first access.

    Each class is imported from its submodule only when the attribute is
    requested, so importing the package does not import every submodule.

    Args:
        name: Attribute name being looked up on this module.

    Returns:
        The requested class (``Policy``, ``PolicyRouter``, ``RLTrainer`` or
        ``ReplayBuffer``).

    Raises:
        AttributeError: If ``name`` is not one of the lazily exported classes.
    """
    # Reject unknown names up front so the branches below only do imports.
    if name not in ("Policy", "PolicyRouter", "RLTrainer", "ReplayBuffer"):
        raise AttributeError(f"module {__name__} has no attribute {name}")

    if name == "Policy":
        from .policy import Policy

        return Policy

    if name == "PolicyRouter":
        from .policy import PolicyRouter

        return PolicyRouter

    if name == "RLTrainer":
        from .trainer import RLTrainer

        return RLTrainer

    # Only "ReplayBuffer" can reach this point.
    from .replay_buffer import ReplayBuffer

    return ReplayBuffer
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
super nit but for OSS I would leave this out. Don't subject others to the idiosyncrasies of our infrastructure
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Currently leaving comments to make our internal use as fast as possible and will remove before oss