4 changes: 4 additions & 0 deletions .gitignore
@@ -15,6 +15,10 @@ __pycache__/
# Cython debug symbols
cython_debug/

# SLURM logs
slogs/
slurm-*

# Celery stuff
celerybeat-schedule
celerybeat.pid
5 changes: 5 additions & 0 deletions apps/grpo/main.py
@@ -21,6 +21,7 @@
from forge.actors.trainer import _qwen3_hf_to_vllm
from forge.cli.config import parse
from forge.controller.actor import ForgeActor
from forge.controller.provisioner import shutdown
from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
from forge.data.rewards import MathReward, ThinkingReward
from forge.util.metric_logging import get_metric_logger
@@ -475,7 +476,11 @@ async def continuous_training():
shutdown_service(compute_advantages),
shutdown_service(ref_model),
shutdown_service(reward_actor),
return_exceptions=True,
)
# TODO - add a global shutdown that implicitly shuts down all services
# and remote allocations
await shutdown()
Member commented:
This would be awesome!



if __name__ == "__main__":
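For context on the return_exceptions=True change above: with that flag, asyncio.gather collects exceptions as result values instead of aborting on the first failure, so every service gets a shutdown attempt before the global shutdown() runs. A minimal standalone sketch (the service names here are illustrative, not the actual GRPO services):

import asyncio


async def stop(name: str, fail: bool = False) -> str:
    # Stand-in for shutdown_service(...); optionally simulates a failure.
    if fail:
        raise RuntimeError(f"{name} failed to stop")
    return f"{name} stopped"


async def main() -> None:
    results = await asyncio.gather(
        stop("dataloader"),
        stop("policy", fail=True),
        stop("trainer"),
        return_exceptions=True,
    )
    # The RuntimeError shows up as a value in results; the other shutdowns still ran.
    print(results)


asyncio.run(main())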
81 changes: 81 additions & 0 deletions apps/grpo/qwen3_multinode.yaml
@@ -0,0 +1,81 @@
# GRPO Training Configuration
# Currently a fork of the main yaml, this just shows
# placement of trainer and inference servers on separate hosts.

# Global configuration
group_size: 8
batch_size: 16
max_req_tokens: 512
max_res_tokens: 512
model: "Qwen/Qwen3-1.7B"

# Dataset configuration
dataset:
  path: "openai/gsm8k"
  revision: "main"
  data_split: "train"
  streaming: true
  model: ${model}
  service:
    procs_per_replica: 1
    num_replicas: 1
    with_gpus: false

# Policy configuration
policy:
  engine_config:
    model: ${model}
    tensor_parallel_size: 1
    pipeline_parallel_size: 1
    enforce_eager: false
  sampling_config:
    n: ${group_size}
    max_tokens: ${max_res_tokens}
    temperature: 1.0
    top_p: 1.0
  service:
    procs_per_replica: 1
    hosts_per_replica: 1
    num_replicas: 1
    with_gpus: true

# Trainer configuration
trainer:
  model_name: ${model}
  learning_rate: 1e-5
  service:
    procs_per_replica: 1
    hosts_per_replica: 1
    num_replicas: 1
    with_gpus: true

# Replay buffer configuration
replay_buffer:
  batch_size: ${batch_size}
  max_policy_age: 1 # Async by 1
  dp_size: 1
  service:
    procs_per_replica: 1
    num_replicas: 1
    with_gpus: false

# Compute advantages configuration
compute_advantages:
  service:
    procs_per_replica: 1
    num_replicas: 1
    with_gpus: false

# Reference model configuration
ref_model:
  service:
    procs_per_replica: 1
    num_replicas: 1
    with_gpus: true

# Reward actor configuration
reward_actor:
  service:
    procs_per_replica: 1
    num_replicas: 1
    with_gpus: false
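For reference, a minimal sketch of how the ${model} / ${group_size} / ${batch_size} interpolations in this file resolve; it assumes the config is loaded with OmegaConf (the library already imported in apps/vllm/main.py) from the path added above:

from omegaconf import OmegaConf

cfg = OmegaConf.load("apps/grpo/qwen3_multinode.yaml")
print(cfg.policy.engine_config.model)   # "Qwen/Qwen3-1.7B", via ${model}
print(cfg.policy.sampling_config.n)     # 8, via ${group_size}
print(cfg.replay_buffer.batch_size)     # 16, via ${batch_size}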
22 changes: 22 additions & 0 deletions apps/vllm/deepseek_r1.yaml
@@ -0,0 +1,22 @@
# NOTE - this won't work until we have proper HostMesh support
policy:
  engine_config:
    model: "deepseek-ai/DeepSeek-R1-0528"
    tensor_parallel_size: 16
    pipeline_parallel_size: 1
    enable_expert_parallel: true
    # enforce_eager: true
  sampling_config:
    n: 2
    guided_decoding: false
    max_tokens: 512
  available_devices: null
  service:
    procs_per_replica: 8
    hosts_per_replica: 2
    num_replicas: 1
    with_gpus: true


# Optional, otherwise argparse fallback kicks in
prompt: "Tell me a joke"
33 changes: 20 additions & 13 deletions apps/vllm/main.py
@@ -11,14 +11,20 @@

import asyncio

import os

from forge.actors.policy import Policy
from forge.cli.config import parse
from forge.controller.provisioner import shutdown
from forge.controller.service import ServiceConfig, shutdown_service, spawn_service

from omegaconf import DictConfig
from src.forge.data.utils import exclude_service
from vllm.outputs import RequestOutput

os.environ["HYPERACTOR_MESSAGE_DELIVERY_TIMEOUT_SECS"] = "600"
Contributor commented:
I think this snuck in from BS that I did and might not be necessary. Unless you found something?

Contributor Author replied:
Ah, this app isn't running with torchstore; it's that when I try to run DeepSeek v3 it takes about an hour to download the weights. I added this here so it doesn't crash.

Contributor replied:
From my understanding, this timeout is decoupled from the e2e latency of an endpoint.

os.environ["HYPERACTOR_CODE_MAX_FRAME_LENGTH"] = "1073741824"


async def run(cfg: DictConfig):

@@ -27,28 +33,29 @@ async def run(cfg: DictConfig):
prompt = "What is 3+5?" if gd else "Tell me a joke"

print("Spawning service...")

policy = await spawn_service(
ServiceConfig(**cfg.policy.service),
Policy,
**exclude_service(cfg.policy),
)

async with policy.session():
print("Requesting generation...")
response_output: RequestOutput = await policy.generate.choose(prompt=prompt)
try:
async with policy.session():
print("Requesting generation...")
response_output: RequestOutput = await policy.generate.choose(prompt=prompt)

print("\nGeneration Results:")
print("=" * 80)
for batch, response in enumerate(response_output.outputs):
print(f"Sample {batch + 1}:")
print(f"User: {prompt}")
print(f"Assistant: {response.text}")
print("-" * 80)
print("\nGeneration Results:")
print("=" * 80)
for batch, response in enumerate(response_output.outputs):
print(f"Sample {batch + 1}:")
print(f"User: {prompt}")
print(f"Assistant: {response.text}")
print("-" * 80)

finally:
print("\nShutting down...")

await shutdown_service(policy)
await shutdown_service(policy)
await shutdown()


@parse
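exclude_service itself isn't part of this diff; as a rough sketch of what it presumably does (an assumption, not the actual src.forge.data.utils implementation), it strips the nested service block so the remaining policy options can be passed straight to the actor:

from omegaconf import DictConfig, OmegaConf


def exclude_service(cfg: DictConfig) -> dict:
    # Drop the nested `service` block so the remaining keys can be splatted
    # into the actor constructor (engine_config, sampling_config, ...).
    return {
        k: v
        for k, v in OmegaConf.to_container(cfg, resolve=True).items()
        if k != "service"
    }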
20 changes: 20 additions & 0 deletions apps/vllm/qwen2_5_32b.yaml
@@ -0,0 +1,20 @@
policy:
  engine_config:
    model: "Qwen/Qwen2.5-32B"
    tensor_parallel_size: 4
    pipeline_parallel_size: 1
    enforce_eager: true
  sampling_config:
    n: 2
    guided_decoding: false
    max_tokens: 512
  available_devices: null
  service:
    procs_per_replica: 4
    hosts_per_replica: 1
    num_replicas: 1
    with_gpus: true


# Optional, otherwise argparse fallback kicks in
prompt: "Tell me a joke"
13 changes: 13 additions & 0 deletions launcher/job.sbatch
@@ -0,0 +1,13 @@
#!/bin/bash
Contributor commented:
I was under the impression that torchx meant we didn't need this? Either way, I think this should live in the GRPO app and still take the config as a conditional.

Contributor Author replied:
Yeah, it's convoluted: we run sbatch to schedule the controller, then the controller calls sbatch through torchx.

We need the controller to run on a GPU node so that Monarch's build doesn't complain, because it's built with the tensor engine. I'm not sure what the right long-term solution is yet.

Contributor Author replied:
I am going to leave it for now, but will think on this more.

#SBATCH --job-name=forge
#SBATCH --output=slogs/forge.out
#SBATCH --error=slogs/forge.err
#SBATCH --partition=h100-low # or h100-high / h100-prod / all
#SBATCH --nodes=1 # 1 node
#SBATCH --ntasks=1 # 1 task (process)
#SBATCH --gres=gpu:8 # request 8 GPUs
#SBATCH --time=01:00:00 # walltime hh:mm:ss

unset SLURM_MEM_PER_CPU SLURM_MEM_PER_GPU SLURM_MEM_PER_NODE
echo "Running on $SLURM_JOB_NODELIST"
python -m apps.grpo.main --config=apps/grpo/multihost.yaml
20 changes: 13 additions & 7 deletions src/forge/actors/policy.py
@@ -4,6 +4,8 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations
Member commented:
?

Contributor Author replied:
Ah, this is needed for EngineConfig so we can write

@classmethod
def as_engine_args(cls, config: Mapping | EngineConfig) ...

Otherwise it would have to be something like

@classmethod
def as_engine_args(cls, config: Mapping | "EngineConfig") ...

and Python complains about the latter.

import asyncio
import os
import sys
@@ -91,6 +93,7 @@ class EngineConfig(EngineArgs):
tensor_parallel_size: int = 1
pipeline_parallel_size: int = 1
enforce_eager: bool = False
enable_expert_parallel: bool = False

# Original method returns False when not run in the main thread
_is_v1_supported_oracle = lambda *_: True
@@ -103,7 +106,8 @@ def from_dict(cls, d: Mapping):
return cls(**valid_args)

def create_vllm_config(self) -> VllmConfig:
# This is not a typo: EngineArgs.create_engine_config
"""Converts the current EngineConfig into vLLM's vLLMConfig."""
# Note: EngineArgs.create_engine_config
# creates a VllmConfig
return self.create_engine_config(UsageContext.LLM_CLASS)

@@ -144,9 +148,15 @@ async def launch(  # pyright: ignore[reportIncompatibleMethodOverride]
# automatically.
worker_procs = await get_proc_mesh(process_config=process_config)

# TODO - we will want to ensure colocation with workers
# TODO - issues/144 we will want to ensure colocation with workers
# We're currently locating the Policy on the local host proc mesh
# vLLM initialization without setting env variables at proc_mesh creation
# level leads to issues.
# Once we can create multiple proc meshes on a host mesh, we can ensure
# host colocation
policy_proc_config = copy(process_config)
policy_proc_config.num_procs = 1
policy_proc_config.num_hosts = None
policy_proc_config.with_gpus = False

policy_proc = await get_proc_mesh(process_config=policy_proc_config)
@@ -201,7 +211,7 @@ async def setup(self):
await self.policy_worker.setup.call()

self.request_id = 0
self.requests: Dict[str, tuple[None | ParentRequest, asyncio.Future]] = {}
self.requests: dict[str, tuple[None | ParentRequest, asyncio.Future]] = {}
self.vllm_config: VllmConfig = self.engine_config.create_vllm_config()

# Setup sampling params
@@ -462,10 +472,6 @@ async def setup_kv_cache(self):
self.worker.initialize_cache(kv_cache_config.num_blocks, 0)
return kv_cache_config

@endpoint
async def get_vllm_config(self) -> VllmConfig:
return self.vllm_config

@endpoint
async def _get_model_params(self) -> dict[str, torch.Tensor]:
model = self.worker.model_runner.model
25 changes: 25 additions & 0 deletions src/forge/controller/actor.py
@@ -56,6 +56,28 @@ async def setup(self):
"""
pass

@endpoint
async def set_env(self, addr: str, port: str):
"""A temporary workaround to set master addr/port.

TODO - issues/144. This should be done in proc_mesh creation.
The ideal path:
- Create a host mesh
- Grab a host from host mesh, from proc 0 spawn an actor that
gets addr/port
- Spawn procs on the HostMesh with addr/port, setting the
addr/port in bootstrap.

We can't currently do this because HostMesh only supports single
proc_mesh creation at the moment. This will be possible once
we have "proper HostMesh support".

"""
import os

os.environ["MASTER_ADDR"] = addr
os.environ["MASTER_PORT"] = port

@classmethod
async def launch(cls, *, process_config: ProcessConfig, **kwargs) -> "ForgeActor":
"""Provisions and deploys a new actor.
@@ -77,6 +99,9 @@ async def launch(cls, *, process_config: ProcessConfig, **kwargs) -> "ForgeActor":
actor = await proc_mesh.spawn(actor_name, cls, **kwargs)
actor._proc_mesh = proc_mesh

if hasattr(proc_mesh, "_hostname") and hasattr(proc_mesh, "_port"):
host, port = proc_mesh._hostname, proc_mesh._port
await actor.set_env.call(addr=host, port=port)
await actor.setup.call()
return actor

Expand Down