Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,10 @@ __pycache__/
# Cython debug symbols
cython_debug/

# SLURM logs
slogs/
slurm-*

# Celery stuff
celerybeat-schedule
celerybeat.pid
Expand Down
5 changes: 5 additions & 0 deletions apps/grpo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from forge.actors.replay_buffer import ReplayBuffer
from forge.cli.config import parse
from forge.controller.actor import ForgeActor
from forge.controller.provisioner import shutdown
from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
from forge.data.rewards import MathReward, ThinkingReward
from forge.util.metric_logging import get_metric_logger
Expand Down Expand Up @@ -480,7 +481,11 @@ async def continuous_training():
shutdown_service(compute_advantages),
shutdown_service(ref_model),
shutdown_service(reward_actor),
return_exceptions=True,
)
# TODO - add a global shutdown that implicitly shuts down all services
# and remote allocations
await shutdown()
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This would be awesome!



@parse
Expand Down
79 changes: 79 additions & 0 deletions apps/grpo/multihost.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
# GRPO Training Configuration
# Currently a fork of the main yaml, this just shows
# placement of trainer and inference servers on separate hosts.

# Global configuration
group_size: 4
batch_size: 4
max_req_tokens: 512
max_res_tokens: 128
model: "Qwen/Qwen3-1.7B-Base"

# Dataset configuration
dataset:
path: "openai/gsm8k"
revision: "main"
data_split: "train"
streaming: true
service:
procs_per_replica: 1
num_replicas: 1
with_gpus: false

# Policy configuration
policy:
engine_config:
model: ${model}
tensor_parallel_size: 1
pipeline_parallel_size: 1
enforce_eager: true
sampling_config:
n: 4
max_tokens: 128
temperature: 1.0
top_p: 1.0
service:
hosts_per_replica: 1 # Places on a remote node
procs_per_replica: 1
num_replicas: 1
with_gpus: true

# Trainer configuration
trainer:
learning_rate: 1e-5
service:
hosts_per_replica: 1 # Places on a remote node
procs_per_replica: 1
num_replicas: 1
with_gpus: true

# Replay buffer configuration
replay_buffer:
batch_size: ${batch_size}
max_policy_age: 0
dp_size: 1
service:
procs_per_replica: 1
num_replicas: 1
with_gpus: false

# Compute advantages configuration
compute_advantages:
service:
procs_per_replica: 1
num_replicas: 1
with_gpus: false

# Reference model configuration
ref_model:
service:
procs_per_replica: 1
num_replicas: 1
with_gpus: true

# Reward actor configuration
reward_actor:
service:
procs_per_replica: 1
num_replicas: 1
with_gpus: false
22 changes: 22 additions & 0 deletions apps/vllm/deepseek_r1.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# NOTE - this won't work until we have proper HostMesh support
policy:
engine_config:
model: "deepseek-ai/DeepSeek-R1-0528"
tensor_parallel_size: 16
pipeline_parallel_size: 1
enable_expert_parallel: true
# enforce_eager: true
sampling_config:
n: 2
guided_decoding: false
max_tokens: 512
available_devices: null
service:
procs_per_replica: 8
hosts_per_replica: 2
num_replicas: 1
with_gpus: true


# Optional, otherwise argparse fallback kicks in
prompt: "Tell me a joke"
39 changes: 23 additions & 16 deletions apps/vllm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,20 @@

import asyncio

import os

from forge.actors.policy import Policy
from forge.cli.config import parse
from forge.controller.provisioner import shutdown
from forge.controller.service import ServiceConfig, shutdown_service, spawn_service

from omegaconf import DictConfig
from src.forge.data.utils import exclude_service
from vllm.outputs import RequestOutput

os.environ["HYPERACTOR_MESSAGE_DELIVERY_TIMEOUT_SECS"] = "600"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this snuck in from BS that I did and might not be necessary. Unless you found something?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah, so this app isn't running with torchstore, it's that when I try and run DeepSeekv3 it takes like an hour to download the weights. I added this here so it doesn't crash

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

From my understanding, this timeout is decoupled from the e2e latency of an endpoint

os.environ["HYPERACTOR_CODE_MAX_FRAME_LENGTH"] = "1073741824"


async def run(cfg: DictConfig):

Expand All @@ -27,28 +33,29 @@ async def run(cfg: DictConfig):
prompt = "What is 3+5?" if gd else "Tell me a joke"

print("Spawning service...")

policy = await spawn_service(
ServiceConfig(**cfg.policy.service),
Policy,
**exclude_service(cfg.policy),
)

async with policy.session():
print("Requesting generation...")
response_output: RequestOutput = await policy.generate.choose(prompt=prompt)

print("\nGeneration Results:")
print("=" * 80)
for batch, response in enumerate(response_output.outputs):
print(f"Sample {batch + 1}:")
print(f"User: {prompt}")
print(f"Assistant: {response.text}")
print("-" * 80)

print("\nShutting down...")

await shutdown_service(policy)
try:
async with policy.session():
print("Requesting generation...")
response_output: RequestOutput = await policy.generate.choose(prompt=prompt)

print("\nGeneration Results:")
print("=" * 80)
for batch, response in enumerate(response_output.outputs):
print(f"Sample {batch + 1}:")
print(f"User: {prompt}")
print(f"Assistant: {response.text}")
print("-" * 80)

print("\nShutting down...")
finally:
await shutdown_service(policy)
await shutdown()


@parse
Expand Down
20 changes: 20 additions & 0 deletions apps/vllm/qwen2_5_32b.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
policy:
engine_config:
model: "Qwen/Qwen2.5-32B"
tensor_parallel_size: 4
pipeline_parallel_size: 1
enforce_eager: true
sampling_config:
n: 2
guided_decoding: false
max_tokens: 512
available_devices: null
service:
procs_per_replica: 4
num_replicas: 1
hosts_per_replica: 1
with_gpus: true


# Optional, otherwise argparse fallback kicks in
prompt: "Tell me a joke"
13 changes: 13 additions & 0 deletions launcher/job.sbatch
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#!/bin/bash
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was under the impression that torchx meant we didn't need this? Either way, I think this should be in the GRPO app and take the config still as a conditional.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah it's convoluted, we run sbatch to schedule the controller, then the controller calls sbatch through torchx.

We need the controller to run on a GPU node so that Monarch's build doesn't complain because it's built with tensor engine. I'm not sure what the right long-term solution is quite yet

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am going to leave it for now, but will think on this more

#SBATCH --job-name=forge
#SBATCH --output=slogs/forge.out
#SBATCH --error=slogs/forge.err
#SBATCH --partition=h100-low # or h100-high / h100-prod / all
#SBATCH --nodes=1 # 1 node
#SBATCH --ntasks=1 # 1 task (process)
#SBATCH --gres=gpu:8 # request 8 GPUs
#SBATCH --time=01:00:00 # walltime hh:mm:ss

unset SLURM_MEM_PER_CPU SLURM_MEM_PER_GPU SLURM_MEM_PER_NODE
echo "Running on $SLURM_JOB_NODELIST"
python -m apps.grpo.main --config=apps/grpo/multihost.yaml
35 changes: 22 additions & 13 deletions src/forge/actors/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah this is needed for EngineConfig so we can do

@classmethod
def as_engine_args(cls, config: Mapping | EngineConfig) ...

otherwise it'd need to be like

@classmethod
def as_engine_args(cls, config: Mapping | "EngineConfig") ...

and Python complains about the latter


import asyncio
import logging
import os
Expand Down Expand Up @@ -92,6 +94,7 @@ class EngineConfig(EngineArgs):
tensor_parallel_size: int = 1
pipeline_parallel_size: int = 1
enforce_eager: bool = False
enable_expert_parallel: bool = False

@classmethod
def from_dict(cls, d: Mapping):
Expand All @@ -100,6 +103,16 @@ def from_dict(cls, d: Mapping):
valid_args = {k: v for k, v in d.items() if k in all_fields}
return cls(**valid_args)

@classmethod
def as_engine_args(cls, config: Mapping | EngineConfig) -> EngineConfig:
if isinstance(config, Mapping):
config = EngineConfig.from_dict(config)

# Original method returns False when not run in the main thread
config._is_v1_supported_oracle = lambda *_: True
# Build Config
return config.create_engine_config(UsageContext.LLM_CLASS)


@dataclass
class Policy(PolicyInterface):
Expand Down Expand Up @@ -138,9 +151,15 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride]
# automatically.
worker_procs = await get_proc_mesh(process_config=process_config)

# TODO - we will want to ensure colocation with workers
# TODO - issues/144 we will want to ensure colocation with workers
# We're currently locating the Policy on the local host proc mesh
# vLLM initialization without setting env variables at proc_mesh creation
# level leads to issues.
# Once we can create multiple proc meshes on a host mesh, we can ensure
# host colocation
policy_proc_config = copy(process_config)
policy_proc_config.num_procs = 1
policy_proc_config.num_hosts = None
policy_proc_config.with_gpus = False

policy_proc = await get_proc_mesh(process_config=policy_proc_config)
Expand Down Expand Up @@ -196,7 +215,7 @@ async def setup(self):

self.request_id = 0
self.requests: Dict[str, tuple[None | ParentRequest, asyncio.Future]] = {}
self.vllm_args = await self.policy_worker.get_vllm_args.choose()
self.vllm_args = EngineConfig.as_engine_args(self.engine_config)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note, I added as_engine_args because when I was trying DeepSeek I saw some weird pickling issues when trying to do the get_vllm_args.choose(). I previously spent several hours trying to debug it until I decided it wasn't worth it.


# Setup sampling params
self.sampling_params = get_default_sampling_params(
Expand Down Expand Up @@ -382,13 +401,7 @@ def __post_init__(self):
- all LLM generate methods, verify against LLM inputs
- all executor methods verify no changes
"""
if isinstance(self.vllm_args, Mapping):
self.vllm_args = EngineConfig.from_dict(self.vllm_args)

# Original method returns False when not run in the main thread
self.vllm_args._is_v1_supported_oracle = lambda *_: True
# Build Config
self.vllm_args = self.vllm_args.create_engine_config(UsageContext.LLM_CLASS)
self.vllm_args = EngineConfig.as_engine_args(self.vllm_args)

@endpoint
async def setup(self, store: MultiProcessStore = None):
Expand Down Expand Up @@ -476,10 +489,6 @@ async def setup_kv_cache(self):
self.worker.initialize_cache(kv_cache_config.num_blocks, 0)
return kv_cache_config

@endpoint
async def get_vllm_args(self):
return self.vllm_args

@endpoint
async def _get_model_params(self) -> Dict[str, torch.Tensor]:
model = self.worker.model_runner.model
Expand Down
25 changes: 25 additions & 0 deletions src/forge/controller/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,28 @@ async def setup(self):
"""
pass

@endpoint
async def set_env(self, addr: str, port: str):
"""A temporary workaround to set master addr/port.
TODO - issues/144. This should be done in proc_mesh creation.
The ideal path:
- Create a host mesh
- Grab a host from host mesh, from proc 0 spawn an actor that
gets addr/port
- Spawn procs on the HostMesh with addr/port, setting the
addr/port in bootstrap.
We can't currently do this because HostMesh only supports single
proc_mesh creation at the moment. This will be possible once
we have "proper HostMesh support".
"""
import os

os.environ["MASTER_ADDR"] = addr
os.environ["MASTER_PORT"] = port

@classmethod
async def launch(cls, *, process_config: ProcessConfig, **kwargs) -> "ForgeActor":
"""Provisions and deploys a new actor.
Expand All @@ -77,6 +99,9 @@ async def launch(cls, *, process_config: ProcessConfig, **kwargs) -> "ForgeActor
actor = await proc_mesh.spawn(actor_name, cls, **kwargs)
actor._proc_mesh = proc_mesh

if hasattr(proc_mesh, "_hostname") and hasattr(proc_mesh, "_port"):
host, port = proc_mesh._hostname, proc_mesh._port
await actor.set_env.call(addr=host, port=port)
await actor.setup.call()
return actor

Expand Down
Loading
Loading