Merged

Commits (49)
da21e1d Add reward interface, math reward, unit tests (DNXie, Aug 21, 2025)
5c72908 Merge branch 'meta-pytorch:main' into main (DNXie, Aug 22, 2025)
b4d7a61 Merge branch 'meta-pytorch:main' into main (DNXie, Aug 25, 2025)
02d77c6 Merge branch 'meta-pytorch:main' into main (DNXie, Aug 27, 2025)
fd1d38b Merge branch 'meta-pytorch:main' into main (DNXie, Aug 28, 2025)
f79beee Merge branch 'meta-pytorch:main' into main (DNXie, Aug 28, 2025)
d8d775a Merge branch 'meta-pytorch:main' into main (DNXie, Sep 2, 2025)
7301e10 Add explicit from_dict methods for PolicyConfig and WorkerConfig (DNXie, Sep 2, 2025)
64687d9 remove (DNXie, Sep 2, 2025)
9278d75 remove policy config (DNXie, Sep 3, 2025)
d2d7107 update grpo.main (DNXie, Sep 3, 2025)
14b5e4a fixed dict attribute error, but still a buggy version (DNXie, Sep 3, 2025)
38f7927 update config (DNXie, Sep 3, 2025)
2a1e021 for debug (DNXie, Sep 3, 2025)
412398c fix the bug (DNXie, Sep 3, 2025)
d998061 clean up (DNXie, Sep 3, 2025)
8d38eb8 lint (DNXie, Sep 3, 2025)
a3e755d add torchstore to dependencies (DNXie, Sep 3, 2025)
9dd396b fix typo (DNXie, Sep 3, 2025)
935fdc1 remove a test file that causes import error (DNXie, Sep 3, 2025)
35fd71e make worker config inherit engineargs (DNXie, Sep 4, 2025)
187a65d add unit test (DNXie, Sep 4, 2025)
0d26242 add config for grpo.main (DNXie, Sep 4, 2025)
ba74b43 Merge branch 'main' into add_config_rl (DNXie, Sep 4, 2025)
063afe6 solve conflict (DNXie, Sep 4, 2025)
a85f7b1 lint (DNXie, Sep 4, 2025)
d94d326 add vllm to unit test dep (DNXie, Sep 4, 2025)
5815656 solve unit test dep (DNXie, Sep 4, 2025)
2a31156 revert back unit_test.yaml and remove config for grpo/main (DNXie, Sep 8, 2025)
4778336 refactor config (DNXie, Sep 8, 2025)
cb42997 rename WorkerConfig to EngineConfig and all worker_params to engine_p… (DNXie, Sep 8, 2025)
eab380a fix test (DNXie, Sep 8, 2025)
d575409 Merge branch 'main' into add_config_rl (DNXie, Sep 8, 2025)
a72f4de rebase (DNXie, Sep 8, 2025)
b19fe24 fix lint (DNXie, Sep 8, 2025)
fc809f8 adding from_dict to samling overrides (DNXie, Sep 8, 2025)
6ca7c2b minor. (DNXie, Sep 8, 2025)
f1c24fb fix test set (DNXie, Sep 8, 2025)
4dc2e89 fix lint and add test for nested field (DNXie, Sep 8, 2025)
00c7fc9 Update src/forge/actors/policy.py (DNXie, Sep 9, 2025)
4445624 Update apps/vllm/main.py (DNXie, Sep 9, 2025)
a7dfd02 Update apps/grpo/main.py (DNXie, Sep 9, 2025)
1ed76c4 rename engineConfig to EngineArgOverrides (DNXie, Sep 9, 2025)
c38685f remove a redundant check (DNXie, Sep 9, 2025)
4191fa6 fix lint (DNXie, Sep 9, 2025)
327828b rename samplingoverrides to samplingconfig, engineargsoverrides to en… (DNXie, Sep 9, 2025)
fe9acae rename, remove redundant logic, refactor (DNXie, Sep 9, 2025)
23e5ef6 fix lint (DNXie, Sep 9, 2025)
7b904fc fix CI (DNXie, Sep 9, 2025)

apps/grpo/main.py: 3 additions & 7 deletions
@@ -13,7 +13,7 @@
import torch
import torch.nn.functional as F
from datasets import load_dataset
from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
from forge.actors.policy import EngineConfig, Policy, SamplingOverrides
from forge.actors.replay_buffer import ReplayBuffer
from forge.controller.actor import ForgeActor
from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
@@ -362,12 +362,8 @@ async def main():
spawn_service(
ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
Policy,
config=PolicyConfig(
worker_params=WorkerConfig(model=model),
sampling_params=SamplingOverrides(
n=group_size, max_tokens=max_res_tokens
),
),
engine_params=EngineConfig(model=model),
sampling_overrides=SamplingOverrides(n=group_size, max_tokens=16),
),
spawn_service(
ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),

apps/vllm/llama3_8b.yaml: 19 additions & 0 deletions
@@ -0,0 +1,19 @@
policy:
  engine_params:
    model: "meta-llama/Llama-3.1-8B-Instruct"
    tensor_parallel_size: 2
    pipeline_parallel_size: 1
    enforce_eager: true
  sampling_overrides:
    n: 2
    guided_decoding: false
    max_tokens: 512
  available_devices: null
  service:
    procs_per_replica: 2
    num_replicas: 1
    with_gpus: true
Comment on lines +12 to +15
Contributor:
Out of scope for this PR, but we should think about the "service in yaml pattern" when we have some breathing room

We're gonna have a pattern of excluding this field when passing args around (since X.service is not a common Agent Arg)

Member Author (DNXie):
Could you be more clear with the suggestions?

Contributor:
I personally think we should have spawn_service handle this to make it less awkward but we can do that later.

Something like

await spawn_service(Policy, **cfg.policy)

where spawn_service(actor: Actor, service_config: ServiceConfig | Mapping, **kwargs)

Contributor:
Agreed, no action required here

Seconding the API too



# Optional, otherwise argparse fallback kicks in
prompt: "Tell me a joke"

apps/vllm/main.py: 20 additions & 76 deletions
@@ -6,95 +6,34 @@

"""To run:
export HF_HUB_DISABLE_XET=1
python -m apps.vllm.main --guided-decoding --num-samples 3

python -m apps.vllm.main --config apps/vllm/llama3_8b.yaml
"""

import argparse
import asyncio
from argparse import Namespace
import sys

from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
from forge.actors.policy import Policy
from forge.cli.config import parse
from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
from vllm.outputs import RequestOutput
from vllm.transformers_utils.tokenizer import get_tokenizer

from omegaconf import DictConfig
from vllm.outputs import RequestOutput

async def main():
"""Main application for running vLLM policy inference."""
args = parse_args()

# Create configuration objects
policy_config, service_config = get_configs(args)
async def run(cfg: DictConfig):

# Resolve the Prompts
if args.prompt is None:
prompt = "What is 3+5?" if args.guided_decoding else "Tell me a joke"
if "prompt" in cfg and cfg["prompt"] is not None:
prompt = cfg["prompt"]
else:
prompt = args.prompt

# format prompt
tokenizer = get_tokenizer(policy_config.worker_params.model)
messages = [{"role": "user", "content": prompt}]
prompt = tokenizer.apply_chat_template(
messages, tokenize=False, add_generation_prompt=True
)

# Run the policy
await run_vllm(service_config, policy_config, prompt)


def parse_args() -> Namespace:
parser = argparse.ArgumentParser(description="VLLM Policy Inference Application")
parser.add_argument(
"--model",
type=str,
default="Qwen/Qwen3-1.7B", # "meta-llama/Llama-3.1-8B-Instruct",
help="Model to use",
)
parser.add_argument(
"--num-samples", type=int, default=2, help="Number of samples to generate"
)
parser.add_argument(
"--guided-decoding", action="store_true", help="Enable guided decoding"
)
parser.add_argument(
"--prompt", type=str, default=None, help="Custom prompt to use for generation"
)
return parser.parse_args()

gd = cfg.policy.get("sampling_overrides", {}).get("guided_decoding", False)
prompt = "What is 3+5?" if gd else "Tell me a joke"

def get_configs(args: Namespace) -> (PolicyConfig, ServiceConfig):

worker_size = 2
worker_params = WorkerConfig(
model=args.model,
tensor_parallel_size=worker_size,
pipeline_parallel_size=1,
enforce_eager=True,
vllm_args=None,
)

sampling_params = SamplingOverrides(
n=args.num_samples,
guided_decoding=args.guided_decoding,
max_tokens=16,
)
print("Spawning service...")

policy_config = PolicyConfig(
worker_params=worker_params, sampling_params=sampling_params
)
service_config = ServiceConfig(
procs_per_replica=worker_size, num_replicas=1, with_gpus=True
policy = await spawn_service(
ServiceConfig(**cfg.policy.service), Policy, **cfg.policy
Contributor:
This seems funky

Does Policy need the service configs args?

Contributor:
They all do, every service needs its own config for the resources it'll get. See previous comment for how this can be made smoother

Contributor:
I might be missing something, but where does Policy use cfg.policy.service

)

return policy_config, service_config


async def run_vllm(service_config: ServiceConfig, config: PolicyConfig, prompt: str):
print("Spawning service...")
policy = await spawn_service(service_config, Policy, config=config)

async with policy.session():
print("Requesting generation...")
response_output: RequestOutput = await policy.generate.choose(prompt=prompt)
@@ -112,5 +51,10 @@ async def run_vllm(service_config: ServiceConfig, config: PolicyConfig, prompt:
await shutdown_service(policy)


@parse
def recipe_main(cfg: DictConfig) -> None:
asyncio.run(run(cfg))


if __name__ == "__main__":
asyncio.run(main())
sys.exit(recipe_main())
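
For a quick check of the new config path, one way to exercise run() without going through the @parse CLI wrapper; a sketch that assumes omegaconf is installed and the repo root is on the import path:

# Sketch only: load the YAML shown above and drive run() directly.
import asyncio

from omegaconf import OmegaConf

from apps.vllm.main import run

cfg = OmegaConf.load("apps/vllm/llama3_8b.yaml")
cfg.prompt = "Tell me a joke"  # optional; same effect as the yaml `prompt` key
asyncio.run(run(cfg))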

src/forge/actors/policy.py: 61 additions & 60 deletions
@@ -8,11 +8,18 @@
import logging
import os
import sys
from collections.abc import Mapping
from copy import copy
from dataclasses import asdict, dataclass, field
from typing import Dict, List
from typing import Any, Dict, List

import torch

from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh

from forge.data.sharding import VLLMSharding
from forge.interfaces import Policy as PolicyInterface
from forge.types import ProcessConfig
from monarch.actor import current_rank, endpoint, ProcMesh
from torchstore import MultiProcessStore
from torchstore._state_dict_utils import DELIM
@@ -37,12 +44,6 @@
from vllm.v1.structured_output import StructuredOutputManager
from vllm.worker.worker_base import WorkerWrapperBase

from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh

from forge.data.sharding import VLLMSharding
from forge.interfaces import Policy as PolicyInterface
from forge.types import ProcessConfig


logger = logging.getLogger(__name__)

Expand All @@ -62,7 +63,7 @@ class SamplingOverrides:
max_tokens: Maximum number of tokens to generate.
"""

n: int
n: int = 1
guided_decoding: bool = False
max_tokens: int = 512

Expand All @@ -72,37 +73,39 @@ def __post_init__(self):
gd_params = GuidedDecodingParams(choice=["Positive", "Negative"])
self.guided_decoding = gd_params

@classmethod
def from_dict(cls, d: Mapping):
d = dict(d)
all_fields = set(cls.__dataclass_fields__.keys())
valid_args = {k: v for k, v in d.items() if k in all_fields}
return cls(**valid_args)
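
As an aside, a small illustration (not part of the diff) of what this filtering buys when values come from YAML:

# Illustration only: keys that are not dataclass fields are dropped instead of
# raising TypeError in the constructor.
overrides = SamplingOverrides.from_dict({"n": 4, "max_tokens": 256, "not_a_field": True})
# -> SamplingOverrides(n=4, guided_decoding=False, max_tokens=256)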


@dataclass
class WorkerConfig:
class EngineConfig(EngineArgs):
"""
Config args used for setting up the policy worker.

Args:
model: Model name.
tensor_parallel_size: Number of tensor parallel workers.
pipeline_parallel_size: Number of pipeline parallel workers.
enforce_eager: Whether to enforce eager mode.
vllm_args: vLLM engine args.
EngineConfig extends EngineArgs with worker-specific fields.
Overlapping keys in input dict will override EngineArgs defaults.
Comment on lines +87 to +88
Contributor:

Suggested change:
- EngineConfig extends EngineArgs with worker-specific fields.
- Overlapping keys in input dict will override EngineArgs defaults.
+ EngineConfig extends EngineArgs surfacing worker-specific fields.
+ Args of this class override EngineArgs defaults.

"""

model: str
model: str = "meta-llama/Llama-3.1-8B-Instruct"
tensor_parallel_size: int = 1
Member Author (DNXie), Sep 8, 2025:
I read it here that tensor_parallel_size is under EngineConfig.parallel_config.tensor_parallel_size. If so, is this implementation correct? Should the user pass the value like this instead:

policy:
  engine_params:
    parallel_config:
      tensor_parallel_size = 1

@pbontrager

Contributor:
I comment on this below, what we have is fine since parallel_config doesn't actually exist until create_engine_config is called

pipeline_parallel_size: int = 1
enforce_eager: bool = False
vllm_args: EngineArgs = None


@dataclass
class PolicyConfig:
worker_params: WorkerConfig
sampling_params: SamplingOverrides
available_devices: str = None
@classmethod
def from_dict(cls, d: Mapping):
d = dict(d)
all_fields = set(cls.__dataclass_fields__.keys())
valid_args = {k: v for k, v in d.items() if k in all_fields}
return cls(**valid_args)
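
Tying this back to the parallel_config question above, a sketch of the reviewer's point: the flat fields here are plain EngineArgs constructor args, and parallel_config only exists on the config object built later. Import path and attribute names follow how vLLM is used elsewhere in this diff; treat the exact shapes as assumptions:

# Sketch only: flat engine_params keys are EngineArgs-style args; parallel_config
# appears only after create_engine_config builds the full vLLM config (which will
# fetch the model's HF config, so it needs model access).
from vllm.usage.usage_lib import UsageContext

engine_args = EngineConfig.from_dict(
    {"model": "meta-llama/Llama-3.1-8B-Instruct", "tensor_parallel_size": 2}
)
assert engine_args.tensor_parallel_size == 2  # flat field on the dataclass
vllm_config = engine_args.create_engine_config(UsageContext.LLM_CLASS)
assert vllm_config.parallel_config.tensor_parallel_size == 2  # nested after building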


@dataclass
class Policy(PolicyInterface):
config: PolicyConfig
engine_params: EngineConfig | Mapping = field(default_factory=EngineConfig)
sampling_overrides: SamplingOverrides = field(default_factory=SamplingOverrides)
available_devices: str | None = None
# Gets set up by setup
sampling_params: SamplingParams | None = None
lora_request: LoRARequest | None = None
@@ -115,13 +118,21 @@ def __post_init__(self):
self._policy_proc: ProcMesh | None = None
self._worker_procs: ProcMesh | None = None
self.weights_version: int = 0
if isinstance(self.engine_params, Mapping):
self.engine_params = EngineConfig.from_dict(self.engine_params)
if isinstance(self.sampling_overrides, Mapping):
self.sampling_overrides = SamplingOverrides.from_dict(
self.sampling_overrides
)
Contributor:
Are we allowing Mapping as an input type just to work around the yaml?

Member Author (DNXie):
Yes. Any suggestions?

Contributor:
No need to change anything here, but worth us thinking about down the line if we should shim this out across the repo (an abstraction that handles all the class constructions, so actors can act on pure python) so that the actor logic is simpler

Member Author (DNXie):
Sounds reasonable. I agree. Let's not include it in this PR for now.
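
To make the Mapping branch concrete, a small illustration; Policy is constructed directly here only to show the conversion (an assumption for clarity; in practice these kwargs arrive through spawn_service from cfg.policy):

# Illustration only: YAML hands over plain dicts, and __post_init__ turns them
# into the typed configs before anything uses them.
policy = Policy(
    engine_params={"model": "meta-llama/Llama-3.1-8B-Instruct", "tensor_parallel_size": 2},
    sampling_overrides={"n": 2, "max_tokens": 512},
)
# After __post_init__, policy.engine_params is an EngineConfig and
# policy.sampling_overrides is a SamplingOverrides.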


@classmethod
async def launch( # pyright: ignore[reportIncompatibleMethodOverride]
cls: type["Policy"],
*,
process_config: ProcessConfig,
config: PolicyConfig,
engine_params: EngineConfig | Mapping = EngineConfig(),
sampling_overrides: SamplingOverrides | Mapping = SamplingOverrides(),
available_devices: str | None = None,
store: MultiProcessStore | None = None,
**kwargs,
) -> "Policy":
@@ -135,16 +146,25 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride]
policy_proc_config.with_gpus = False

policy_proc = await get_proc_mesh(process_config=policy_proc_config)

if isinstance(engine_params, Mapping):
engine_params = EngineConfig.from_dict(engine_params)

if isinstance(sampling_overrides, Mapping):
sampling_overrides = SamplingOverrides(**sampling_overrides)

workers = await worker_procs.spawn(
"vllm_worker", PolicyWorker, **asdict(config.worker_params)
"vllm_worker", PolicyWorker, vllm_args=engine_params
)

# TODO - expand support so name can stick within kwargs
actor_name = kwargs.pop("name", cls.__name__)
policy = await policy_proc.spawn(
actor_name,
cls,
config=config,
engine_params=engine_params,
sampling_overrides=sampling_overrides,
available_devices=available_devices,
policy_worker=workers,
store=store,
)
@@ -182,7 +202,7 @@ async def setup(self):

# Setup sampling params
self.sampling_params = get_default_sampling_params(
self.vllm_args, overrides=asdict(self.config.sampling_params)
self.vllm_args, overrides=asdict(self.sampling_overrides)
)

# Setup processors
@@ -348,11 +368,7 @@ async def stop(self):

@dataclass
class PolicyWorker(ForgeActor):
model: str
tensor_parallel_size: int = 1
pipeline_parallel_size: int = 1
enforce_eager: bool = False
vllm_args: EngineArgs = None
vllm_args: EngineConfig | dict = EngineConfig()
Contributor:
So there is a typing/design quirk that runs throughout the repo, that I'm not a huge fan of, but it's out of scope for this PR

The types used during construction aren't necessarily the same types as the fields after construction. In this case vllm_args as input can be a dict | EngineArgs, but dict gets converted to EngineArgs (to make calling easier with the yaml?), and in __post_init__ it further gets transformed into a vLLMConfig

state_dict_key: str = "model_state_dict"

def __post_init__(self):
@@ -368,31 +384,14 @@ def __post_init__(self):
- all LLM generate methods, verify against LLM inputs
- all executor methods verify no changes
"""
if self.vllm_args is None:
# Use default vllm EngineArgs
self.vllm_args = EngineArgs(
model=self.model,
tensor_parallel_size=self.tensor_parallel_size,
pipeline_parallel_size=self.pipeline_parallel_size,
enforce_eager=self.enforce_eager,
if isinstance(self.vllm_args, dict):
self.vllm_args = EngineConfig.from_dict(self.vllm_args)
elif not isinstance(self.vllm_args, EngineConfig):
raise TypeError(
f"vllm_args must be a EngineConfig or dict, got {type(self.vllm_args)}"
)
# Original method returns False when not run in the main thread
self.vllm_args._is_v1_supported_oracle = lambda *_: True
else:
# Check that provided args match Policy args
cfg = [
"model",
"tensor_parallel_size",
"pipeline_parallel_size",
"data_parallel_size",
]
for key in cfg:
value = getattr(self, key) if key != "data_parallel_size" else 1
if getattr(self.vllm_args, key) != value:
logger.warning(
f"{key} args don't match value in EngineArgs, overriding with {value}"
)
setattr(self.vllm_args, key, value)
# Original method returns False when not run in the main thread
self.vllm_args._is_v1_supported_oracle = lambda *_: True
# Build Config
self.vllm_args = self.vllm_args.create_engine_config(UsageContext.LLM_CLASS)

@@ -416,7 +415,9 @@ async def _load_tensor_parallel_state_dict(

updated_count = 0
# setting explicitly to llama3 for now as it's our only use case
sharding = VLLMSharding(self.tensor_parallel_size, self.rank)
sharding = VLLMSharding(
self.vllm_args.parallel_config.tensor_parallel_size, self.rank
)

for param_name in current_state_dict.keys():
current_tensor = current_state_dict[param_name]