Conversation

@allenwang28 (Contributor) commented Aug 7, 2025

This implements the RLTrainer actor following the pattern from apps/sft_v2. The PR currently covers setup and implements train_step; I'm putting it up now to land the trainer actor and unblock weight-sync work. RLLoss is still missing and is blocked on an update in ForgeEngine. Once this lands and the ForgeEngine update is in, the trainer will be integrated into apps/grpo (pending #58) and apps/rl will be removed.
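
For orientation, a rough sketch of the intended call flow once the follow-ups land; the endpoint names (setup, train_step, push_weights) come from this PR, while the driver loop, the batches iterable, and the endpoint call syntax are assumptions:

# Hypothetical driver sketch, not part of this PR.
await trainer.setup.call()
for batch in batches:                  # batches: placeholder source of training data
    await trainer.train_step.call(batch)
    await trainer.push_weights.call()  # hand updated weights to the policy (weight sync)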

@meta-cla bot added the CLA Signed label (Aug 7, 2025)
@pbontrager self-assigned this (Aug 8, 2025)
@pbontrager changed the title from "[wip] working towards RL prototype..." to "RLTrainer" (Aug 26, 2025)
@pbontrager marked this pull request as ready for review (August 26, 2025, 20:36)
Comment on lines +106 to +114
if env:

    def setup():  # noqa: FB811
        for k, v in env.items():
            os.environ[k] = v

    p = await ProcMesh.from_alloc(alloc, setup=setup)
else:
    p = await ProcMesh.from_alloc(alloc)

Contributor:

nit: No strong pref

Suggested change
-if env:
-    def setup():  # noqa: FB811
-        for k, v in env.items():
-            os.environ[k] = v
-    p = await ProcMesh.from_alloc(alloc, setup=setup)
-else:
-    p = await ProcMesh.from_alloc(alloc)
+def setup():  # noqa: FB811
+    if env:
+        for k, v in env.items():
+            os.environ[k] = v
+p = await ProcMesh.from_alloc(alloc, setup=setup)

Contributor:

I wanted to pass None instead of a noop function in case it's faster, and the linter didn't let me conditionally define setup as None or a function.
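
For reference, the variant described here (pass None instead of a no-op) would look roughly like this; treating setup=None as "no hook" in ProcMesh.from_alloc is an assumption, and the conditional rebinding of setup is exactly what the linter rejected:

# Sketch of the rejected variant: bind `setup` to None or to a function,
# then always pass it through.
setup = None
if env:
    def setup():  # the linter flags this as a redefinition of `setup`
        for k, v in env.items():
            os.environ[k] = v

p = await ProcMesh.from_alloc(alloc, setup=setup)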

Comment on lines -16 to -18
from forge.controller import spawn_actors
from forge.controller.service import ServiceConfig
from forge.controller.spawn import spawn_service

Contributor:

I blame my linter for missing this

Thanks

Comment on lines +58 to +60
@parse
def recipe_main(cfg: DictConfig) -> None:
    asyncio.run(run(cfg))

Contributor:

Is the indirection through run because @parse doesn't play well with async, vs. putting @parse on run directly?

Contributor:

We probably shouldn't have recipe_main just call run, but I left it this way for the reason you said.


async def run(cfg: DictConfig):
    trainer, buffer = await asyncio.gather(
        spawn_actors(

Contributor:

Doesn't matter for the sake of the apps, but why not spawn_services?

Contributor:

I had this before spawn_services existed, and the next step is integration with grpo, so it wasn't worth updating.

tokenizer: Tokenizer
train_dataloader: Dataloader
# val_dataloader: Dataloader
metric_logger: MetricLogger

Contributor:

Suggested change
-metric_logger: MetricLogger
+metric_logger: MetricLogger | None
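
If the field does become optional, call sites would need a guard; a minimal sketch, where the log call shown is illustrative rather than the actual MetricLogger API:

# Sketch: tolerate a missing logger instead of assuming one is configured.
if self.metric_logger is not None:
    self.metric_logger.log("loss", loss.item(), step=self.current_step)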

    return loss

@endpoint
def train_step(self, batch) -> None:

Contributor:

Not in this PR since things are still moving, but let's type batch
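
One possible annotation once the interface settles, assuming batches stay plain dicts of tensors (which is what the body below already expects):

# Sketch: type the batch as a mapping from field name to tensor.
@endpoint
def train_step(self, batch: dict[str, torch.Tensor]) -> None:
    ...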

    for k, v in batch.items():
        if isinstance(v, torch.Tensor):
            batch[k] = v.to("cuda")  # TODO: hardcoded for now
    self.train_step(batch)

Member:

Maybe I'm misreading my indents, but train_step calls train_step?

Contributor:

Forgot to remove that line
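
For clarity, the hunk without the leftover call reads:

# Same hunk with the stray self.train_step(batch) removed; the loss and
# optimizer logic that follows in the method is elided here.
for k, v in batch.items():
    if isinstance(v, torch.Tensor):
        batch[k] = v.to("cuda")  # TODO: hardcoded for now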

    )

@endpoint
def push_weights(self) -> None:

Member:

nit: slightly more descriptive name here.



-async def get_proc_mesh(process_config: ProcessConfig) -> ProcMesh:
+async def get_proc_mesh(process_config: ProcessConfig, set_address=False) -> ProcMesh:

Member:

This does the automatic GPU alloc right? Is there a reference from Ray on how they do this?

Contributor:

This is just adding distributed variables that need to be shared across the proc_mesh. This should also be handled by the host_mesh setup in the future.
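
Concretely, "distributed variables shared across the proc_mesh" means something like the sketch below; the specific variables (MASTER_ADDR, MASTER_PORT) and how the address is chosen are assumptions about what get_proc_mesh does when set_address=True:

# Sketch: pick one rendezvous address for the whole mesh and export it to
# every process via the setup() hook, so torch.distributed can initialize.
import socket

def _pick_port() -> int:
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]

env = {
    "MASTER_ADDR": socket.gethostname(),
    "MASTER_PORT": str(_pick_port()),
}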

    self._init_dist()
    super().__init__(job_config)

def _init_dist(self):

Member:

Should probably go inside Monarch or a general helper function at some point

Contributor:

This is brought over from sft_v2. There is a monarch helper that landed yesterday, but in general we're waiting for hostmesh support so we can handle this more robustly.
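
For readers without sft_v2 open, _init_dist is roughly shaped like the sketch below; the rank and world-size lookups are placeholders because the real values come from Monarch, and the exact variable list is an assumption:

import os

def _init_dist(self):
    # Sketch: export the env vars torch.distributed / torchtitan expect,
    # derived from this actor's position in the proc mesh (placeholders).
    env = {
        "RANK": str(self._rank),              # placeholder: mesh rank
        "LOCAL_RANK": str(self._local_rank),  # placeholder
        "WORLD_SIZE": str(self._world_size),  # placeholder: mesh size
    }
    os.environ.update(env)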

@ebsmothers (Contributor) left a comment:

A bunch of small comments, stamping to unblock

num_hosts: 1
num_procs: 1

# policy:

Contributor:

Remove from here on down?

apps/rl/main.py Outdated
"""A working example showcasing a practical example of forge with RL.
Run this with:
    HF_HUB_DISABLE_XET=1 python -m apps.rl.main --config apps/rl/llama3_8b.yaml

Contributor:

Why do we need HF_HUB_DISABLE_XET here? We aren't downloading anything.

apps/rl/main.py Outdated
import logging
import sys

# from forge.actors import Policy, PolicyRouter, RLTrainer, ReplayBuffer

Contributor:

remove


__all__ = ["Collector"]

def __getattr__(name):

Contributor:

You trying to sneak this in or what? 👀

Member:

NO

pyproject.toml Outdated
"torch==2.9.0.dev20250826",
"monarch-no-torch==0.1.0.dev20250826",
]
# cpu = [

Contributor:

What's this about?

Contributor:

Removing it, I forgot it was there. It's broken with that addition, but it needs to be fixed separately.


# apply context parallelism if cp is enabled
# ensure CP handles the separate freqs_cis buffer for each pp stage
inputs = input_dict["tokens"]

Contributor:

FYI: for any models with flex you will need to add the changes from #76.

# self.pbar.set_description(f"{self.current_step}|Loss: {loss}")

self.optimizers.step()
self.optimizers.zero_grad()

Contributor:

👀

@endpoint
def train_step(self, batch) -> None:
    # Move tensors to the appropriate device
    for k, v in batch.items():

Contributor:

nit: we have your utility batch_to_device now
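
For reference, such a helper amounts to the sketch below; the real utility referenced in the comment may differ in name, location, and signature:

# Sketch of a batch_to_device-style helper: move every tensor value in the
# batch dict to the target device in place.
import torch

def batch_to_device(batch: dict, device: torch.device) -> None:
    for k, v in batch.items():
        if isinstance(v, torch.Tensor):
            batch[k] = v.to(device)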

    )

@endpoint
def push_weights(self) -> None:

Contributor:

nit: add a comment explaining what this will be used for
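
Something along these lines would do; the description of the weight-sync flow below is inferred from the PR summary rather than from final code:

@endpoint
def push_weights(self) -> None:
    """Publish the trainer's current weights so policy/generator actors
    can pull them before the next round of rollouts (weight sync).
    Placeholder until the actual sync mechanism lands."""
    ...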

actor_cls: ForgeActor,
cfg: DictConfig,
processes: ProcessConfig,
set_address: bool = False,

Contributor:

Why/when do we actually need this? (I saw you use it in the spawn for ForgeSFTRecipe, but why there and nowhere else?)

Contributor:

Changes will propagate, but you need a proper way to set a unique communication address per proc_mesh. This isn't at the actor level. It will be used for any mesh that uses distributed APIs, including the policy.
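
In other words, every mesh that runs collectives gets its own rendezvous address; a usage sketch, assuming set_address=True is what triggers that per-mesh assignment inside get_proc_mesh (the trainer_procs and policy_procs configs here are hypothetical):

# Sketch: two independent meshes, each with its own MASTER_ADDR/MASTER_PORT,
# so their process groups never collide.
trainer_mesh = await get_proc_mesh(trainer_procs, set_address=True)
policy_mesh = await get_proc_mesh(policy_procs, set_address=True)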

@pbontrager merged commit 13fd9f1 into meta-pytorch:main on Aug 27, 2025
4 checks passed