Merged

Commits (49)
da21e1d  Add reward interface, math reward, unit tests (DNXie, Aug 21, 2025)
5c72908  Merge branch 'meta-pytorch:main' into main (DNXie, Aug 22, 2025)
b4d7a61  Merge branch 'meta-pytorch:main' into main (DNXie, Aug 25, 2025)
02d77c6  Merge branch 'meta-pytorch:main' into main (DNXie, Aug 27, 2025)
fd1d38b  Merge branch 'meta-pytorch:main' into main (DNXie, Aug 28, 2025)
f79beee  Merge branch 'meta-pytorch:main' into main (DNXie, Aug 28, 2025)
d8d775a  Merge branch 'meta-pytorch:main' into main (DNXie, Sep 2, 2025)
7301e10  Add explicit from_dict methods for PolicyConfig and WorkerConfig (DNXie, Sep 2, 2025)
64687d9  remove (DNXie, Sep 2, 2025)
9278d75  remove policy config (DNXie, Sep 3, 2025)
d2d7107  update grpo.main (DNXie, Sep 3, 2025)
14b5e4a  fixed dict attribute error, but still a buggy version (DNXie, Sep 3, 2025)
38f7927  update config (DNXie, Sep 3, 2025)
2a1e021  for debug (DNXie, Sep 3, 2025)
412398c  fix the bug (DNXie, Sep 3, 2025)
d998061  clean up (DNXie, Sep 3, 2025)
8d38eb8  lint (DNXie, Sep 3, 2025)
a3e755d  add torchstore to dependencies (DNXie, Sep 3, 2025)
9dd396b  fix typo (DNXie, Sep 3, 2025)
935fdc1  remove a test file that causes import error (DNXie, Sep 3, 2025)
35fd71e  make worker config inherit engineargs (DNXie, Sep 4, 2025)
187a65d  add unit test (DNXie, Sep 4, 2025)
0d26242  add config for grpo.main (DNXie, Sep 4, 2025)
ba74b43  Merge branch 'main' into add_config_rl (DNXie, Sep 4, 2025)
063afe6  solve conflict (DNXie, Sep 4, 2025)
a85f7b1  lint (DNXie, Sep 4, 2025)
d94d326  add vllm to unit test dep (DNXie, Sep 4, 2025)
5815656  solve unit test dep (DNXie, Sep 4, 2025)
2a31156  revert back unit_test.yaml and remove config for grpo/main (DNXie, Sep 8, 2025)
4778336  refactor config (DNXie, Sep 8, 2025)
cb42997  rename WorkerConfig to EngineConfig and all worker_params to engine_p… (DNXie, Sep 8, 2025)
eab380a  fix test (DNXie, Sep 8, 2025)
d575409  Merge branch 'main' into add_config_rl (DNXie, Sep 8, 2025)
a72f4de  rebase (DNXie, Sep 8, 2025)
b19fe24  fix lint (DNXie, Sep 8, 2025)
fc809f8  adding from_dict to samling overrides (DNXie, Sep 8, 2025)
6ca7c2b  minor. (DNXie, Sep 8, 2025)
f1c24fb  fix test set (DNXie, Sep 8, 2025)
4dc2e89  fix lint and add test for nested field (DNXie, Sep 8, 2025)
00c7fc9  Update src/forge/actors/policy.py (DNXie, Sep 9, 2025)
4445624  Update apps/vllm/main.py (DNXie, Sep 9, 2025)
a7dfd02  Update apps/grpo/main.py (DNXie, Sep 9, 2025)
1ed76c4  rename engineConfig to EngineArgOverrides (DNXie, Sep 9, 2025)
c38685f  remove a redundant check (DNXie, Sep 9, 2025)
4191fa6  fix lint (DNXie, Sep 9, 2025)
327828b  rename samplingoverrides to samplingconfig, engineargsoverrides to en… (DNXie, Sep 9, 2025)
fe9acae  rename, remove redundant logic, refactor (DNXie, Sep 9, 2025)
23e5ef6  fix lint (DNXie, Sep 9, 2025)
7b904fc  fix CI (DNXie, Sep 9, 2025)
19 changes: 19 additions & 0 deletions apps/vllm/config.yaml
@@ -0,0 +1,19 @@
+policy_config:
+  worker_params:
+    model: "meta-llama/Llama-3.1-8B-Instruct"
+    tensor_parallel_size: 2
+    pipeline_parallel_size: 1
+    enforce_eager: true
+    vllm_args: null
+  sampling_params:
+    num_samples: 2
+    guided_decoding: false
+  available_devices: null
+
+service_config:
+  procs_per_replica: 2
+  num_replicas: 1
+  with_gpus: true
+
+# Optional, otherwise argparse fallback kicks in
+prompt: "Tell me a joke"
103 changes: 35 additions & 68 deletions apps/vllm/main.py
@@ -6,81 +6,34 @@

"""To run:
export HF_HUB_DISABLE_XET=1
python -m apps.vllm.main --guided-decoding --num-samples 3

python -m apps.vllm.main --config apps/vllm/config.yaml
"""

import argparse
import asyncio
from argparse import Namespace
from typing import List
import sys
from typing import Any

import yaml

from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
from forge.actors.policy import Policy, PolicyConfig
from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
from vllm.outputs import CompletionOutput


async def main():
"""Main application for running vLLM policy inference."""
args = parse_args()
def load_yaml_config(path: str) -> dict:
with open(path, "r") as f:
return yaml.safe_load(f)

# Create configuration objects
policy_config, service_config = get_configs(args)

# Resolve the Prompts
if args.prompt is None:
prompt = "What is 3+5?" if args.guided_decoding else "Tell me a joke"
def get_configs(cfg: dict) -> tuple[PolicyConfig, ServiceConfig, str]:
# Instantiate PolicyConfig and ServiceConfig from nested dicts
policy_config = PolicyConfig.from_dict(cfg["policy_config"])
service_config = ServiceConfig(**cfg["service_config"])
if "prompt" in cfg and cfg["prompt"] is not None:
prompt = cfg["prompt"]
else:
prompt = args.prompt

# Run the policy
await run_vllm(service_config, policy_config, prompt)


def parse_args() -> Namespace:
parser = argparse.ArgumentParser(description="VLLM Policy Inference Application")
parser.add_argument(
"--model",
type=str,
default="meta-llama/Llama-3.1-8B-Instruct",
help="Model to use",
)
parser.add_argument(
"--num-samples", type=int, default=2, help="Number of samples to generate"
)
parser.add_argument(
"--guided-decoding", action="store_true", help="Enable guided decoding"
)
parser.add_argument(
"--prompt", type=str, default=None, help="Custom prompt to use for generation"
)
return parser.parse_args()


def get_configs(args: Namespace) -> (PolicyConfig, ServiceConfig):

worker_size = 2
worker_params = WorkerConfig(
model=args.model,
tensor_parallel_size=worker_size,
pipeline_parallel_size=1,
enforce_eager=True,
vllm_args=None,
)

sampling_params = SamplingOverrides(
num_samples=args.num_samples,
guided_decoding=args.guided_decoding,
)

policy_config = PolicyConfig(
worker_params=worker_params, sampling_params=sampling_params
)
service_config = ServiceConfig(
procs_per_replica=worker_size, num_replicas=1, with_gpus=True
)

return policy_config, service_config
gd = policy_config.sampling_params.guided_decoding
prompt = "What is 3+5?" if gd else "Tell me a joke"
return policy_config, service_config, prompt


async def run_vllm(service_config: ServiceConfig, config: PolicyConfig, prompt: str):
Expand All @@ -89,11 +42,11 @@ async def run_vllm(service_config: ServiceConfig, config: PolicyConfig, prompt:

     async with policy.session():
         print("Requesting generation...")
-        responses: List[CompletionOutput] = await policy.generate.choose(prompt=prompt)
+        response_output = await policy.generate.choose(prompt=prompt)
 
         print("\nGeneration Results:")
         print("=" * 80)
-        for batch, response in enumerate(responses):
+        for batch, response in enumerate(response_output.outputs):
             print(f"Sample {batch + 1}:")
             print(f"User: {prompt}")
             print(f"Assistant: {response.text}")
@@ -104,5 +57,19 @@ async def run_vllm(service_config: ServiceConfig, config: PolicyConfig, prompt: str):
     await shutdown_service(policy)
 
 
+def main():
+    import argparse
+
+    parser = argparse.ArgumentParser(description="vLLM Policy Inference Application")
+    parser.add_argument(
+        "--config", type=str, required=True, help="Path to YAML config file"
+    )
+    args = parser.parse_args()
+
+    cfg = load_yaml_config(args.config)
+    policy_config, service_config, prompt = get_configs(cfg)
+    asyncio.run(run_vllm(service_config, policy_config, prompt))
+
+
 if __name__ == "__main__":
-    asyncio.run(main())
+    sys.exit(main())
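
For reference, a short sketch (not part of the diff, and assuming apps.vllm.main is importable in your environment) of driving the new entry points programmatically, using only functions added in this PR:

    import asyncio

    from apps.vllm.main import get_configs, load_yaml_config, run_vllm

    cfg = load_yaml_config("apps/vllm/config.yaml")  # plain dict from YAML
    policy_config, service_config, prompt = get_configs(cfg)
    asyncio.run(run_vllm(service_config, policy_config, prompt))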
28 changes: 23 additions & 5 deletions src/forge/actors/policy.py
@@ -61,7 +61,7 @@ class SamplingOverrides:
     guided_decoding: Whether to use guided decoding.
     """
 
-    num_samples: int
+    num_samples: int = 1
[Review thread]
Contributor:
You need to pull changes from main here, maybe this can inherit from GuidedDecoding in vLLM too

DNXie (Member Author):
Rebased. Will leave the inheritance to future PR.

     guided_decoding: bool = False
     max_tokens: int = 512

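A quick illustration (not from the diff): with the defaults in the hunk above, SamplingOverrides can now be constructed with no arguments. Field names and defaults are taken directly from this PR:

    from forge.actors.policy import SamplingOverrides

    sp = SamplingOverrides()  # num_samples=1, guided_decoding=False, max_tokens=512
    sp_custom = SamplingOverrides(num_samples=4, guided_decoding=True)
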
@@ -79,19 +79,35 @@ class WorkerConfig:
     vllm_args: vLLM engine args.
     """
 
-    model: str
+    model: str = "meta-llama/Llama-3.1-8B-Instruct"
     tensor_parallel_size: int = 1
[Review thread]
DNXie (Member Author), Sep 8, 2025:
I read it here that tensor_parallel_size is under EngineConfig.parallel_config.tensor_parallel_size. If so, is this implementation correct? Should the user pass the value like this instead:

policy:
  engine_params:
    parallel_config:
      tensor_parallel_size = 1

@pbontrager

Contributor:
I comment on this below; what we have is fine since parallel_config doesn't actually exist until create_engine_config is called.

     pipeline_parallel_size: int = 1
     enforce_eager: bool = False
-    vllm_args: EngineArgs = None
+    vllm_args: EngineArgs = field(default_factory=EngineArgs)
 
+    @classmethod
+    def from_dict(cls, d: dict):
+        d = dict(d)  # copy
+        if "vllm_args" in d and isinstance(d["vllm_args"], dict):
+            d["vllm_args"] = EngineArgs(**d["vllm_args"])
+        return cls(**d)
 
 
 @dataclass
 class PolicyConfig:
-    worker_params: WorkerConfig
-    sampling_params: SamplingOverrides
+    worker_params: WorkerConfig = field(default_factory=WorkerConfig)
+    sampling_params: SamplingOverrides = field(default_factory=SamplingOverrides)
     available_devices: str = None
 
+    @classmethod
+    def from_dict(cls, d: dict):
+        d = dict(d)
+        if "worker_params" in d and isinstance(d["worker_params"], dict):
+            d["worker_params"] = WorkerConfig.from_dict(d["worker_params"])
+        if "sampling_params" in d and isinstance(d["sampling_params"], dict):
+            d["sampling_params"] = SamplingOverrides(**d["sampling_params"])
+        return cls(**d)
 
 
 @dataclass
 class Policy(PolicyInterface):
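
On the parallel_config point in the review thread above: a minimal sketch (assuming vLLM's public API, which varies by version; not part of this diff) of where parallel_config comes from:

    from vllm.engine.arg_utils import EngineArgs

    # parallel_config does not exist on EngineArgs itself; it is materialized
    # when the engine config is created:
    args = EngineArgs(model="meta-llama/Llama-3.1-8B-Instruct", tensor_parallel_size=2)
    engine_config = args.create_engine_config()
    print(engine_config.parallel_config.tensor_parallel_size)  # 2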
@@ -108,6 +124,8 @@ def __post_init__(self):
         self._policy_proc: ProcMesh | None = None
         self._worker_procs: ProcMesh | None = None
         self.weights_version: int = 0
+        if isinstance(self.config, dict):
+            self.config = PolicyConfig.from_dict(self.config)
 
     @classmethod
     async def launch(  # pyright: ignore[reportIncompatibleMethodOverride]
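
To illustrate the new config plumbing end to end, a small sketch (assumptions: forge is importable, field names as in this diff, and the vllm_args entry is a hypothetical EngineArgs field):

    from forge.actors.policy import PolicyConfig

    raw = {
        "worker_params": {
            "model": "meta-llama/Llama-3.1-8B-Instruct",
            "tensor_parallel_size": 2,
            "vllm_args": {"gpu_memory_utilization": 0.9},  # dict -> EngineArgs(**...)
        },
        "sampling_params": {"num_samples": 2, "guided_decoding": False},
    }
    pc = PolicyConfig.from_dict(raw)
    assert pc.worker_params.tensor_parallel_size == 2
    assert pc.sampling_params.num_samples == 2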