Skip to content
5 changes: 5 additions & 0 deletions apps/rl/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
138 changes: 138 additions & 0 deletions apps/rl/llama3_8b.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
# Config for GRPO finetuning using a Llama3.1 8B Instruct model
#
# This config assumes that you've run the following command before launching
# this run:
# export HF_HUB_DISABLE_XET=1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

super nit but for OSS I would leave this out. Don't subject others to the idiosyncracies of our infrastructure

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently leaving comments to make our internal use as fast as possible and will remove before oss

# uv run forge download meta-llama/Meta-Llama-3.1-8B-Instruct


trainer:
comm:
trace_buf_size: 0

model:
name: llama3
flavor: 8B
tokenizer_path: /tmp/Meta-Llama-3.1-8B-Instruct

processes:
scheduler: local # local | mast (not supported yet)
num_hosts: 1
num_procs: 4

optimizer:
name: AdamW
lr: 1e-5
eps: 1e-8

lr_scheduler:
warmup_steps: 1

training:
local_batch_size: 1
seq_len: 2048
max_norm: 1.0
steps: 5
compile: false
dataset: "c4"

parallelism:
data_parallel_replicate_degree: 1
data_parallel_shard_degree: -1
tensor_parallel_degree: 1
pipeline_parallel_degree: 1
context_parallel_degree: 1
expert_parallel_degree: 1
disable_loss_parallel: false

checkpoint:
enable_checkpoint: true
folder: /tmp/Meta-Llama-3.1-8B-Instruct/saved_checkpoints
initial_load_path: /tmp/Meta-Llama-3.1-8B-Instruct/
initial_load_in_hf: true
last_save_in_hf: true
interval: 500
async_mode: "disabled"

activation_checkpoint:
mode: selective
selective_ac_option: op

replay_buffer:
batch_size: 2
max_policy_age: 2
seed: None
processes:
scheduler: local # local | mast (not supported yet)
num_hosts: 1
num_procs: 1

# policy:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove from here on down?

# scheduler:
# scheduler: local # local | mast (not supported yet)
# num_hosts: 1
# num_gpus: 1
# oncall: torchtune
# identity: pytorch_distributed
# image: forge_workspace:latest
#
# model: "meta-llama/Llama-3.1-8B-Instruct"
# tensor_parallel_size: 2
# pipeline_parallel_size: 1
# enforce_eager: false

# postprocessor:
# scheduler:
# scheduler: local # local | mast (not supported yet)
# num_hosts: 1
# num_gpus: 1
# oncall: torchtune
# identity: pytorch_distributed
# image: forge_workspace:latest
#
# comm:
# trace_buf_size: 0
#
# optimizer:
# name: AdamW
# lr: 1e-5
# eps: 1e-8
#
# lr_scheduler:
# warmup_steps: 1
#
# training:
# local_batch_size: 1
# seq_len: 2048
# max_norm: 1.0
# steps: 5
# compile: false
# dataset: "c4"
#
# parallelism:
# data_parallel_replicate_degree: 1
# data_parallel_shard_degree: -1
# tensor_parallel_degree: 1
# pipeline_parallel_degree: 1
# context_parallel_degree: 1
# expert_parallel_degree: 1
# disable_loss_parallel: false
#
# checkpoint:
# enable_checkpoint: true
# folder: /tmp/Meta-Llama-3.1-8B-Instruct/
# interval: 500
# async_mode: "disabled"
#
# activation_checkpoint:
# mode: selective
# selective_ac_option: op
#

# profiling:
# enable_profiling: false

# metrics:
# log_freq: 10
# enable_tensorboard: true
# save_tb_folder: "tb"
63 changes: 63 additions & 0 deletions apps/rl/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""A working example showcasing a practical example of forge with RL.
Run this with:
python -m apps.rl.main --config apps/rl/llama3_8b.yaml
"""

import asyncio
import logging
import sys

from forge.actors import ReplayBuffer, RLTrainer

from forge.cli.config import parse
from forge.controller import spawn_actors
from omegaconf import DictConfig


logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


async def run(cfg: DictConfig):
trainer, buffer = await asyncio.gather(
spawn_actors(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Doesn't matter for the sake of the apps, but why not spawn_services?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I had this before spawn_services and the next step is integration with grpo so it wasn't worth updating

name="trainer",
actor_cls=RLTrainer,
cfg={"config": cfg.trainer},
processes=cfg.trainer.pop("processes"),
set_address=True,
),
spawn_actors(
name="replay_buffer",
actor_cls=ReplayBuffer,
cfg=cfg.replay_buffer,
processes=cfg.replay_buffer.pop("processes"),
),
)
print("Actors spawned")

# Initialize everything
await asyncio.gather(
buffer.setup.call(),
trainer.setup.call(),
)
print("Setup done")

print("shutting down...")
await asyncio.gather(*[a.mesh.stop() for a in [trainer]])


@parse
def recipe_main(cfg: DictConfig) -> None:
asyncio.run(run(cfg))
Comment on lines +57 to +59
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the indirection to run because parse doesn't play well with async? vs @parse on run

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably not have run run but I left it this way for the reason you said



if __name__ == "__main__":
sys.exit(recipe_main())
23 changes: 15 additions & 8 deletions apps/sft_v2/llama3_8b.yaml
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
# Config for supervised full finetuning using a Llama3.1 8B Instruct model
#
# This config assumes that you've run the following command before launching
# this run:
# export HF_HUB_DISABLE_XET=1
# uv run forge download meta-llama/Meta-Llama-3.1-8B-Instruct

# profiling:
# enable_profiling: false

# metrics:
# log_freq: 10
# enable_tensorboard: true
# save_tb_folder: "tb"

# TODO: required by torchtitan
# https://github.com/pytorch/torchtitan/blob/2f1c814da071cc8ad165d00be6f9c1a66f8e1cce/torchtitan/distributed/utils.py#L265
Expand All @@ -20,7 +19,7 @@ model:
processes:
scheduler: local # local | mast (not supported yet)
num_hosts: 1
num_gpus: 8
num_procs: 8

optimizer:
name: AdamW
Expand Down Expand Up @@ -59,3 +58,11 @@ checkpoint:
activation_checkpoint:
mode: selective
selective_ac_option: op

# profiling:
# enable_profiling: false

# metrics:
# log_freq: 10
# enable_tensorboard: true
# save_tb_folder: "tb"
23 changes: 12 additions & 11 deletions apps/sft_v2/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
import math
import os
import sys
from dataclasses import asdict
from functools import partial
from typing import Any

Expand Down Expand Up @@ -71,7 +70,11 @@ class ForgeSFTRecipe(ForgeActor, ForgeEngine):
device: torch.device
step: int

def __init__(self, job_config: ForgeJobConfig):
def __init__(self, config: DictConfig):
job_config = ForgeJobConfig().to_dict()
# Hack to deal with literal types from titan
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit but the hack is on the to_dict(), right? (Also fine to just remove the comment completely)

job_config = OmegaConf.merge(job_config, config)

self.current_step = 0
self.num_training_steps = job_config.training.steps
self.metric_logger = None # TODO: fix this
Expand All @@ -92,8 +95,6 @@ def _init_dist(self):
"""
env = {
"MASTER_ADDR": "localhost",
"MASTER_PORT": "12345",
"RANK": str(self._rank),
"LOCAL_RANK": str(self._rank),
"LOCAL_WORLD_SIZE": str(self._size),
Expand All @@ -103,7 +104,6 @@ def _init_dist(self):
"ROLE_WORLD_SIZE": str(self._size),
"ROLE_NAME": "rank",
"WORLD_SIZE": str(self._size),
"CUDA_VISIBLE_DEVICES": str(self._rank),
"PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
}
os.environ.update(env)
Expand Down Expand Up @@ -280,7 +280,13 @@ def __repr__(self) -> str:
async def run(cfg: DictConfig) -> None:
logging.info("Spawing recipe...")
process_cfg = cfg.pop("processes")
recipe = await spawn_actors("sft", ForgeSFTRecipe, cfg, process_cfg)
recipe = await spawn_actors(
"sft",
ForgeSFTRecipe,
{"config": cfg},
process_cfg,
set_address=True,
)

logging.info("Created recipe, running setup.")
await recipe.setup.call()
Expand All @@ -296,11 +302,6 @@ async def run(cfg: DictConfig) -> None:

@parse
def recipe_main(cfg: DictConfig) -> None:
# TODO: this is a hack to get the defaults from ForgeJobConfig
default_cfg = ForgeJobConfig()
# Hack to deal with literal types from titan
default_cfg = asdict(default_cfg)
cfg = OmegaConf.merge(default_cfg, cfg)
asyncio.run(run(cfg))


Expand Down
4 changes: 2 additions & 2 deletions apps/toy_rl/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import torch
from forge.actors.collector import Collector

from forge.data.replay_buffer import ReplayBuffer
from forge.actors.replay_buffer import ReplayBuffer
from forge.interfaces import Environment, Policy
from forge.types import Action, Observation, State
from monarch.actor import endpoint, proc_mesh
Expand Down Expand Up @@ -255,7 +255,7 @@ async def replay_buffer_sampler_task():
)

print(
f" Step {i+1:2d}: State={state_value:6.2f} → Action={action_value:6.2f}"
f" Step {i + 1:2d}: State={state_value:6.2f} → Action={action_value:6.2f}"
)

if idx < len(trajectories): # Add spacing between trajectories
Expand Down
23 changes: 21 additions & 2 deletions src/forge/actors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,25 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from .collector import Collector
__all__ = ["Policy", "PolicyRouter", "RLTrainer", "ReplayBuffer"]

__all__ = ["Collector"]

def __getattr__(name):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You trying to sneak this in or what? 👀

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NO

if name == "Policy":
from .policy import Policy

return Policy
elif name == "PolicyRouter":
from .policy import PolicyRouter

return PolicyRouter
elif name == "RLTrainer":
from .trainer import RLTrainer

return RLTrainer
elif name == "ReplayBuffer":
from .replay_buffer import ReplayBuffer

return ReplayBuffer
else:
raise AttributeError(f"module {__name__} has no attribute {name}")
6 changes: 3 additions & 3 deletions src/forge/actors/collector.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@

from typing import Callable

from forge.data.replay_buffer import ReplayBuffer
from monarch.actor import Actor, endpoint

from forge.actors.replay_buffer import ReplayBuffer

from forge.interfaces import Policy

from forge.types import Trajectory

from monarch.actor import Actor, endpoint


class Collector(Actor):
"""Collects trajectories for the training loop."""
Expand Down
Loading
Loading