Merged
Changes from 20 commits (49 commits total)
da21e1d
Add reward interface, math reward, unit tests
DNXie Aug 21, 2025
5c72908
Merge branch 'meta-pytorch:main' into main
DNXie Aug 22, 2025
b4d7a61
Merge branch 'meta-pytorch:main' into main
DNXie Aug 25, 2025
02d77c6
Merge branch 'meta-pytorch:main' into main
DNXie Aug 27, 2025
fd1d38b
Merge branch 'meta-pytorch:main' into main
DNXie Aug 28, 2025
f79beee
Merge branch 'meta-pytorch:main' into main
DNXie Aug 28, 2025
d8d775a
Merge branch 'meta-pytorch:main' into main
DNXie Sep 2, 2025
7301e10
Add explicit from_dict methods for PolicyConfig and WorkerConfig
DNXie Sep 2, 2025
64687d9
remove
DNXie Sep 2, 2025
9278d75
remove policy config
DNXie Sep 3, 2025
d2d7107
update grpo.main
DNXie Sep 3, 2025
14b5e4a
fixed dict attribute error, but still a buggy version
DNXie Sep 3, 2025
38f7927
update config
DNXie Sep 3, 2025
2a1e021
for debug
DNXie Sep 3, 2025
412398c
fix the bug
DNXie Sep 3, 2025
d998061
clean up
DNXie Sep 3, 2025
8d38eb8
lint
DNXie Sep 3, 2025
a3e755d
add torchstore to dependencies
DNXie Sep 3, 2025
9dd396b
fix typo
DNXie Sep 3, 2025
935fdc1
remove a test file that causes import error
DNXie Sep 3, 2025
35fd71e
make worker config inherit engineargs
DNXie Sep 4, 2025
187a65d
add unit test
DNXie Sep 4, 2025
0d26242
add config for grpo.main
DNXie Sep 4, 2025
ba74b43
Merge branch 'main' into add_config_rl
DNXie Sep 4, 2025
063afe6
solve conflict
DNXie Sep 4, 2025
a85f7b1
lint
DNXie Sep 4, 2025
d94d326
add vllm to unit test dep
DNXie Sep 4, 2025
5815656
solve unit test dep
DNXie Sep 4, 2025
2a31156
revert back unit_test.yaml and remove config for grpo/main
DNXie Sep 8, 2025
4778336
refactor config
DNXie Sep 8, 2025
cb42997
rename WorkerConfig to EngineConfig and all worker_params to engine_p…
DNXie Sep 8, 2025
eab380a
fix test
DNXie Sep 8, 2025
d575409
Merge branch 'main' into add_config_rl
DNXie Sep 8, 2025
a72f4de
rebase
DNXie Sep 8, 2025
b19fe24
fix lint
DNXie Sep 8, 2025
fc809f8
adding from_dict to sampling overrides
DNXie Sep 8, 2025
6ca7c2b
minor.
DNXie Sep 8, 2025
f1c24fb
fix test set
DNXie Sep 8, 2025
4dc2e89
fix lint and add test for nested field
DNXie Sep 8, 2025
00c7fc9
Update src/forge/actors/policy.py
DNXie Sep 9, 2025
4445624
Update apps/vllm/main.py
DNXie Sep 9, 2025
a7dfd02
Update apps/grpo/main.py
DNXie Sep 9, 2025
1ed76c4
rename engineConfig to EngineArgOverrides
DNXie Sep 9, 2025
c38685f
remove a redundant check
DNXie Sep 9, 2025
4191fa6
fix lint
DNXie Sep 9, 2025
327828b
rename samplingoverrides to samplingconfig, engineargsoverrides to en…
DNXie Sep 9, 2025
fe9acae
rename, remove redundant logic, refactor
DNXie Sep 9, 2025
23e5ef6
fix lint
DNXie Sep 9, 2025
7b904fc
fix CI
DNXie Sep 9, 2025
10 changes: 3 additions & 7 deletions apps/grpo/main.py
@@ -12,7 +12,7 @@

import torch
from datasets import load_dataset
from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
from forge.actors.policy import Policy, SamplingOverrides, WorkerConfig
from forge.actors.reference_actor import compute_sequence_logprobs, TitanRefModel
from forge.actors.replay_buffer import ReplayBuffer
from forge.controller.actor import ForgeActor
@@ -305,12 +305,8 @@ async def main():
spawn_service(
ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
Policy,
config=PolicyConfig(
worker_params=WorkerConfig(model=model),
sampling_params=SamplingOverrides(
num_samples=group_size, max_tokens=16
),
),
worker_params=WorkerConfig(model=model),
sampling_params=SamplingOverrides(num_samples=group_size, max_tokens=16),
),
spawn_service(
ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
19 changes: 19 additions & 0 deletions apps/vllm/llama3_8b.yaml
@@ -0,0 +1,19 @@
policy:
  worker_params:
    model: "meta-llama/Llama-3.1-8B-Instruct"
    tensor_parallel_size: 2
    pipeline_parallel_size: 1
    enforce_eager: true
    vllm_args: null
  sampling_params:
    num_samples: 2
    guided_decoding: false
  available_devices: null

service_config:
  procs_per_replica: 2
  num_replicas: 1
  with_gpus: true

# Optional, otherwise argparse fallback kicks in
prompt: "Tell me a joke"
93 changes: 22 additions & 71 deletions apps/vllm/main.py
@@ -6,94 +6,40 @@

"""To run:
export HF_HUB_DISABLE_XET=1
python -m apps.vllm.main --guided-decoding --num-samples 3

python -m apps.vllm.main --config apps/vllm/llama3_8b.yaml
"""

import argparse
import asyncio
from argparse import Namespace
from typing import List
import sys

from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
from forge.actors.policy import Policy
from forge.cli.config import parse
from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
from vllm.outputs import CompletionOutput

from omegaconf import DictConfig
from vllm.outputs import RequestOutput

async def main():
"""Main application for running vLLM policy inference."""
args = parse_args()

# Create configuration objects
policy_config, service_config = get_configs(args)
async def run(cfg: DictConfig):

# Resolve the Prompts
if args.prompt is None:
prompt = "What is 3+5?" if args.guided_decoding else "Tell me a joke"
if "prompt" in cfg and cfg["prompt"] is not None:
prompt = cfg["prompt"]
else:
prompt = args.prompt

# Run the policy
await run_vllm(service_config, policy_config, prompt)


def parse_args() -> Namespace:
parser = argparse.ArgumentParser(description="VLLM Policy Inference Application")
parser.add_argument(
"--model",
type=str,
default="meta-llama/Llama-3.1-8B-Instruct",
help="Model to use",
)
parser.add_argument(
"--num-samples", type=int, default=2, help="Number of samples to generate"
)
parser.add_argument(
"--guided-decoding", action="store_true", help="Enable guided decoding"
)
parser.add_argument(
"--prompt", type=str, default=None, help="Custom prompt to use for generation"
)
return parser.parse_args()


def get_configs(args: Namespace) -> (PolicyConfig, ServiceConfig):
gd = cfg.policy.get("sampling_params", {}).get("guided_decoding", False)
prompt = "What is 3+5?" if gd else "Tell me a joke"

worker_size = 2
worker_params = WorkerConfig(
model=args.model,
tensor_parallel_size=worker_size,
pipeline_parallel_size=1,
enforce_eager=True,
vllm_args=None,
)

sampling_params = SamplingOverrides(
num_samples=args.num_samples,
guided_decoding=args.guided_decoding,
)
print("Spawning service...")

policy_config = PolicyConfig(
worker_params=worker_params, sampling_params=sampling_params
)
service_config = ServiceConfig(
procs_per_replica=worker_size, num_replicas=1, with_gpus=True
policy = await spawn_service(
ServiceConfig(**cfg.service_config), Policy, **cfg.policy
)

return policy_config, service_config


async def run_vllm(service_config: ServiceConfig, config: PolicyConfig, prompt: str):
print("Spawning service...")
policy = await spawn_service(service_config, Policy, config=config)

async with policy.session():
print("Requesting generation...")
responses: List[CompletionOutput] = await policy.generate.choose(prompt=prompt)
response_output: RequestOutput = await policy.generate.choose(prompt=prompt)

print("\nGeneration Results:")
print("=" * 80)
for batch, response in enumerate(responses):
for batch, response in enumerate(response_output.outputs):
print(f"Sample {batch + 1}:")
print(f"User: {prompt}")
print(f"Assistant: {response.text}")
Expand All @@ -104,5 +50,10 @@ async def run_vllm(service_config: ServiceConfig, config: PolicyConfig, prompt:
await shutdown_service(policy)


@parse
def recipe_main(cfg: DictConfig) -> None:
asyncio.run(run(cfg))


if __name__ == "__main__":
asyncio.run(main())
sys.exit(recipe_main())
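
For context, a rough sketch of what the @parse entry point amounts to for this app, assuming forge.cli.config.parse loads the --config YAML into a DictConfig and hands it to the wrapped function; the exact behavior of parse is not shown in this diff, and any CLI override support is not assumed here.

# Rough sketch (assumption): what recipe_main effectively does for
#   python -m apps.vllm.main --config apps/vllm/llama3_8b.yaml
import asyncio

from omegaconf import OmegaConf

from apps.vllm.main import run  # the new async entry point added in this PR

cfg = OmegaConf.load("apps/vllm/llama3_8b.yaml")  # what --config points at
asyncio.run(run(cfg))
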
70 changes: 46 additions & 24 deletions src/forge/actors/policy.py
@@ -8,18 +8,14 @@
import logging
import os
import sys
from collections.abc import Mapping
from copy import copy
from dataclasses import asdict, dataclass, field
from typing import Dict, List

import torch

from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh

from forge.data.sharding import VLLMSharding
from forge.interfaces import Policy as PolicyInterface
from forge.types import ProcessConfig
from monarch.actor import current_rank, endpoint, ProcMesh
from omegaconf import DictConfig
from torchstore import MultiProcessStore
from torchstore._state_dict_utils import DELIM

@@ -43,6 +39,12 @@
from vllm.v1.structured_output import StructuredOutputManager
from vllm.worker.worker_base import WorkerWrapperBase

from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh

from forge.data.sharding import VLLMSharding
from forge.interfaces import Policy as PolicyInterface
from forge.types import ProcessConfig


logger = logging.getLogger(__name__)

@@ -61,7 +63,7 @@ class SamplingOverrides:
guided_decoding: Whether to use guided decoding.
"""

num_samples: int
num_samples: int = 1
Contributor:
You need to pull changes from main here, maybe this can inherit from GuidedDecoding in vLLM too

Member Author:
Rebased. Will leave the inheritance to future PR.

guided_decoding: bool = False
max_tokens: int = 512

@@ -79,23 +81,27 @@ class WorkerConfig:
vllm_args: vLLM engine args.
"""

model: str
model: str = "meta-llama/Llama-3.1-8B-Instruct"
tensor_parallel_size: int = 1
Member Author (@DNXie, Sep 8, 2025):
I read it here that tensor_parallel_size is under EngineConfig.parallel_config.tensor_parallel_size. If so, is this implementation correct? Should the user pass the value like this instead:

policy:
  engine_params:
    parallel_config:
      tensor_parallel_size: 1

@pbontrager

Contributor:
I comment on this below; what we have is fine since parallel_config doesn't actually exist until create_engine_config is called.

pipeline_parallel_size: int = 1
enforce_eager: bool = False
vllm_args: EngineArgs = None

vllm_args: EngineArgs = field(default_factory=EngineArgs)

@dataclass
class PolicyConfig:
worker_params: WorkerConfig
sampling_params: SamplingOverrides
available_devices: str = None
@classmethod
def from_dict(cls, d: dict):
d = dict(d)
if "vllm_args" in d and isinstance(d["vllm_args"], dict):
d["vllm_args"] = EngineArgs(**d["vllm_args"])
else:
d["vllm_args"] = EngineArgs()
return cls(**d)


@dataclass
class Policy(PolicyInterface):
config: PolicyConfig
worker_params: WorkerConfig = field(default_factory=WorkerConfig)
sampling_overrides: SamplingOverrides = field(default_factory=SamplingOverrides)
available_devices: str | None = None
# Gets set up by setup
sampling_params: SamplingParams | None = None
lora_request: LoRARequest | None = None
@@ -108,13 +114,19 @@ def __post_init__(self):
self._policy_proc: ProcMesh | None = None
self._worker_procs: ProcMesh | None = None
self.weights_version: int = 0
if isinstance(self.worker_params, Mapping):
self.worker_params = WorkerConfig.from_dict(self.worker_params)
if isinstance(self.sampling_overrides, dict):
self.sampling_overrides = SamplingOverrides(**self.sampling_overrides)

@classmethod
async def launch( # pyright: ignore[reportIncompatibleMethodOverride]
cls: type["Policy"],
*,
process_config: ProcessConfig,
config: PolicyConfig,
worker_params: WorkerConfig | dict = WorkerConfig(),
sampling_overrides: SamplingOverrides | dict = SamplingOverrides(),
available_devices: str | None = None,
store: MultiProcessStore | None = None,
**kwargs,
) -> "Policy":
@@ -128,16 +140,26 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride]
policy_proc_config.with_gpus = False

policy_proc = await get_proc_mesh(process_config=policy_proc_config)
workers = await worker_procs.spawn(
"vllm_worker", PolicyWorker, **asdict(config.worker_params)
)

if isinstance(worker_params, (dict, DictConfig)):
worker_params = WorkerConfig.from_dict(worker_params)

if isinstance(sampling_overrides, (dict, DictConfig)):
sampling_overrides = SamplingOverrides(**sampling_overrides)

worker_dict = asdict(worker_params)
worker_dict["vllm_args"] = worker_params.vllm_args

workers = await worker_procs.spawn("vllm_worker", PolicyWorker, **worker_dict)

# TODO - expand support so name can stick within kwargs
actor_name = kwargs.pop("name", cls.__name__)
policy = await policy_proc.spawn(
actor_name,
cls,
config=config,
worker_params=worker_params,
sampling_overrides=sampling_overrides,
available_devices=available_devices,
policy_worker=workers,
store=store,
)
@@ -174,7 +196,7 @@ async def setup(self):
self.vllm_args = await self.policy_worker.get_vllm_args.choose()

# Setup sampling params
sampling_overrides = self.config.sampling_params
sampling_overrides = self.sampling_overrides
overrides = {
"n": sampling_overrides.num_samples,
"guided_decoding": (
@@ -371,8 +393,6 @@ def __post_init__(self):
pipeline_parallel_size=self.pipeline_parallel_size,
enforce_eager=self.enforce_eager,
)
# Original method returns False when not run in the main thread
self.vllm_args._is_v1_supported_oracle = lambda *_: True
else:
# Check that provided args match Policy args
cfg = [
@@ -388,6 +408,8 @@ def __post_init__(self):
f"{key} args don't match value in EngineArgs, overriding with {value}"
)
setattr(self.vllm_args, key, value)
# Original method returns False when not run in the main thread
self.vllm_args._is_v1_supported_oracle = lambda *_: True
# Build Config
self.vllm_args = self.vllm_args.create_engine_config(UsageContext.LLM_CLASS)

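
Finally, a small usage sketch (not from the PR) of the WorkerConfig.from_dict helper added above: a nested vllm_args mapping is promoted to vLLM EngineArgs, and omitted keys fall back to the dataclass defaults. The gpu_memory_utilization value and the EngineArgs import path are illustrative assumptions rather than anything this diff prescribes.

# Usage sketch (assumption) for the from_dict classmethod shown above.
from vllm.engine.arg_utils import EngineArgs  # import path is an assumption

from forge.actors.policy import WorkerConfig

raw = {
    "model": "meta-llama/Llama-3.1-8B-Instruct",
    "tensor_parallel_size": 2,
    "vllm_args": {"gpu_memory_utilization": 0.9},  # illustrative EngineArgs field
}
worker_params = WorkerConfig.from_dict(raw)
assert isinstance(worker_params.vllm_args, EngineArgs)
assert worker_params.pipeline_parallel_size == 1  # default from the dataclass
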