RLTrainer #40

Conversation
if env:

    def setup():  # noqa: FB811
        for k, v in env.items():
            os.environ[k] = v

    p = await ProcMesh.from_alloc(alloc, setup=setup)
else:
    p = await ProcMesh.from_alloc(alloc)
nit: No strong pref
Suggested change:

-if env:
-    def setup():  # noqa: FB811
-        for k, v in env.items():
-            os.environ[k] = v
-    p = await ProcMesh.from_alloc(alloc, setup=setup)
-else:
-    p = await ProcMesh.from_alloc(alloc)
+def setup():  # noqa: FB811
+    if env:
+        for k, v in env.items():
+            os.environ[k] = v
+p = await ProcMesh.from_alloc(alloc, setup=setup)
I wanted to pass None instead of a noop function in case it's faster, and the linter didn't let me conditionally define setup as None or a function.
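For illustration, a minimal sketch of a third option that avoids both the conditional def and a no-op function (only a sketch against the snippet above; env, alloc, and ProcMesh.from_alloc are assumed to be in scope exactly as shown there):

import functools
import os

# Only pass `setup` when there are env vars to forward; os.environ.update(env)
# does the same work as the inline setup() loop above.
kwargs = {"setup": functools.partial(os.environ.update, env)} if env else {}
p = await ProcMesh.from_alloc(alloc, **kwargs)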
from forge.controller import spawn_actors
from forge.controller.service import ServiceConfig
from forge.controller.spawn import spawn_service
I blame my linter for missing this
Thanks
@parse
def recipe_main(cfg: DictConfig) -> None:
    asyncio.run(run(cfg))
Is the indirection to run because parse doesn't play well with async, vs. putting @parse on run directly?
We should probably not have recipe_main just call run, but I left it this way for the reason you said.
async def run(cfg: DictConfig):
    trainer, buffer = await asyncio.gather(
        spawn_actors(
Doesn't matter for the sake of the apps, but why not spawn_services?
I had this before spawn_services existed, and the next step is integration with grpo, so it wasn't worth updating.
src/forge/actors/trainer.py (Outdated)
    tokenizer: Tokenizer
    train_dataloader: Dataloader
    # val_dataloader: Dataloader
    metric_logger: MetricLogger
Suggested change:

-    metric_logger: MetricLogger
+    metric_logger: MetricLogger | None
        return loss

    @endpoint
    def train_step(self, batch) -> None:
Not in this PR since things are still moving, but let's type batch
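Purely as a sketch for that later pass, one possible annotation (the exact value type is an assumption; the isinstance check in train_step suggests the values are mostly, but maybe not all, tensors):

from typing import Any

@endpoint
def train_step(self, batch: dict[str, Any]) -> None:
    ...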
src/forge/actors/trainer.py (Outdated)
        for k, v in batch.items():
            if isinstance(v, torch.Tensor):
                batch[k] = v.to("cuda")  # TODO: hardcoded for now
        self.train_step(batch)
Maybe I'm misreading my indents, but train_step calls train_step?
Forgot to remove that line
        )

    @endpoint
    def push_weights(self) -> None:
nit: slightly more descriptive name here.
-async def get_proc_mesh(process_config: ProcessConfig) -> ProcMesh:
+async def get_proc_mesh(process_config: ProcessConfig, set_address=False) -> ProcMesh:
This does the automatic GPU alloc right? Is there a reference from Ray on how they do this?
This is just adding distributed variables that need to be shared across the proc_mesh. This should also be handled by the host_mesh setup in the future.
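For context, a rough sketch of the kind of variables meant here; the names and the port-picking trick are illustrative assumptions, not what get_proc_mesh actually does. The point is that they have to agree across every process in one proc_mesh but differ between meshes:

import socket

def _find_free_port() -> int:
    # Ask the OS for an unused port (a common trick; not necessarily how
    # get_proc_mesh picks one).
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        return s.getsockname()[1]

def _distributed_env() -> dict[str, str]:
    # torch.distributed rendezvous variables that every process in a single
    # proc_mesh must share; they could be forwarded via the setup hook above.
    return {
        "MASTER_ADDR": socket.gethostname(),
        "MASTER_PORT": str(_find_free_port()),
    }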
        self._init_dist()
        super().__init__(job_config)

    def _init_dist(self):
Should probably go inside Monarch or a general helper function at some point
This is brought over from sft_v2. There is a monarch helper that landed yesterday, but in general we're waiting for hostmesh support so we can handle this more robustly.
A bunch of small comments, stamping to unblock
num_hosts: 1
num_procs: 1

# policy:
Remove from here on down?
apps/rl/main.py (Outdated)
"""A working example showcasing a practical example of forge with RL. | ||
Run this with: | ||
HF_HUB_DISABLE_XET=1 python -m apps.rl.main --config apps/rl/llama3_8b.yaml |
Why do we need HF_HUB_DISABLE_XET here, we aren't downloading anything?
apps/rl/main.py (Outdated)
import logging
import sys

# from forge.actors import Policy, PolicyRouter, RLTrainer, ReplayBuffer
remove
__all__ = ["Collector"]

def __getattr__(name):
You trying to sneak this in or what? 👀
NO
pyproject.toml (Outdated)
"torch==2.9.0.dev20250826", | ||
"monarch-no-torch==0.1.0.dev20250826", | ||
] | ||
# cpu = [ |
What's this about?
Removing, I forgot it there. It's broken with that addition, but it needs to be fixed separately.
    # apply context parallelism if cp is enabled
    # ensure CP handles the separate freqs_cis buffer for each pp stage
    inputs = input_dict["tokens"]
FYI, for any models with flex you will need to add the changes from #76.
        # self.pbar.set_description(f"{self.current_step}|Loss: {loss}")

        self.optimizers.step()
        self.optimizers.zero_grad()
👀
    @endpoint
    def train_step(self, batch) -> None:
        # Move tensors to the appropriate device
        for k, v in batch.items():
nit: we have your utility batch_to_device now
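Something along these lines, as a sketch (the import path and exact signature of batch_to_device are assumptions here, not taken from the repo):

import torch
# from forge.util import batch_to_device  # assumed location

device = torch.device("cuda")  # still hardcoded, per the TODO above
batch_to_device(batch, device)  # replaces the manual isinstance/.to("cuda") loop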
        )

    @endpoint
    def push_weights(self) -> None:
nit: add a comment explaining what this will be used for
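E.g. something along these lines, as a sketch only (the wording is a suggestion based on the PR description's note that this unblocks weight sync, not the actual comment):

    @endpoint
    def push_weights(self) -> None:
        # Publish the trainer's current weights so policy/inference workers
        # can sync to them; the actual transport is follow-up weight-sync work.
        ...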
    actor_cls: ForgeActor,
    cfg: DictConfig,
    processes: ProcessConfig,
    set_address: bool = False,
Why/when do we actually need this? (I saw you use it in the spawn for ForgeSFTRecipe, but why there and nowhere else?)
Changes will propagate, but you need a proper way to set a unique communication address per proc_mesh, and this isn't something that lives at the actor level. It will be needed for any mesh that uses distributed APIs, including policy.
This implements the RLTrainer actor, following the pattern from apps/sft_v2. Currently the PR handles setup and implements train_step. I'm putting it up to land the trainer actor and unblock weight sync work. It is still missing RLLoss, which is blocked on an update in ForgeEngine. After this lands and the ForgeEngine update, the trainer will be integrated into apps/grpo (pending #58) and apps/rl will be removed.