Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
da21e1d
Add reward interface, math reward, unit tests
DNXie Aug 21, 2025
5c72908
Merge branch 'meta-pytorch:main' into main
DNXie Aug 22, 2025
b4d7a61
Merge branch 'meta-pytorch:main' into main
DNXie Aug 25, 2025
02d77c6
Merge branch 'meta-pytorch:main' into main
DNXie Aug 27, 2025
fd1d38b
Merge branch 'meta-pytorch:main' into main
DNXie Aug 28, 2025
f79beee
Merge branch 'meta-pytorch:main' into main
DNXie Aug 28, 2025
d8d775a
Merge branch 'meta-pytorch:main' into main
DNXie Sep 2, 2025
7301e10
Add explicit from_dict methods for PolicyConfig and WorkerConfig
DNXie Sep 2, 2025
64687d9
remove
DNXie Sep 2, 2025
9278d75
remove policy config
DNXie Sep 3, 2025
d2d7107
update grpo.main
DNXie Sep 3, 2025
14b5e4a
fixed dict attribute error, but still a buggy version
DNXie Sep 3, 2025
38f7927
update config
DNXie Sep 3, 2025
2a1e021
for debug
DNXie Sep 3, 2025
412398c
fix the bug
DNXie Sep 3, 2025
d998061
clean up
DNXie Sep 3, 2025
8d38eb8
lint
DNXie Sep 3, 2025
a3e755d
add torchstore to dependencies
DNXie Sep 3, 2025
9dd396b
fix typo
DNXie Sep 3, 2025
935fdc1
remove a test file that causes import error
DNXie Sep 3, 2025
35fd71e
make worker config inherit engineargs
DNXie Sep 4, 2025
187a65d
add unit test
DNXie Sep 4, 2025
0d26242
add config for grpo.main
DNXie Sep 4, 2025
ba74b43
Merge branch 'main' into add_config_rl
DNXie Sep 4, 2025
063afe6
solve conflict
DNXie Sep 4, 2025
a85f7b1
lint
DNXie Sep 4, 2025
d94d326
add vllm to unit test dep
DNXie Sep 4, 2025
5815656
solve unit test dep
DNXie Sep 4, 2025
2a31156
revert back unit_test.yaml and remove config for grpo/main
DNXie Sep 8, 2025
4778336
refactor config
DNXie Sep 8, 2025
cb42997
rename WorkerConfig to EngineConfig and all worker_params to engine_p…
DNXie Sep 8, 2025
eab380a
fix test
DNXie Sep 8, 2025
d575409
Merge branch 'main' into add_config_rl
DNXie Sep 8, 2025
a72f4de
rebase
DNXie Sep 8, 2025
b19fe24
fix lint
DNXie Sep 8, 2025
fc809f8
adding from_dict to sampling overrides
DNXie Sep 8, 2025
6ca7c2b
minor.
DNXie Sep 8, 2025
f1c24fb
fix test set
DNXie Sep 8, 2025
4dc2e89
fix lint and add test for nested field
DNXie Sep 8, 2025
00c7fc9
Update src/forge/actors/policy.py
DNXie Sep 9, 2025
4445624
Update apps/vllm/main.py
DNXie Sep 9, 2025
a7dfd02
Update apps/grpo/main.py
DNXie Sep 9, 2025
1ed76c4
rename engineConfig to EngineArgOverrides
DNXie Sep 9, 2025
c38685f
remove a redundant check
DNXie Sep 9, 2025
4191fa6
fix lint
DNXie Sep 9, 2025
327828b
rename samplingoverrides to samplingconfig, engineargsoverrides to en…
DNXie Sep 9, 2025
fe9acae
rename, remove redundant logic, refactor
DNXie Sep 9, 2025
23e5ef6
fix lint
DNXie Sep 9, 2025
7b904fc
fix CI
DNXie Sep 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions apps/grpo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import torch
import torch.nn.functional as F
from datasets import load_dataset
from forge.actors.policy import EngineConfig, Policy, SamplingOverrides
from forge.actors.policy import EngineArgOverrides, Policy, SamplingOverrides
from forge.actors.replay_buffer import ReplayBuffer
from forge.controller.actor import ForgeActor
from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
Expand Down Expand Up @@ -362,8 +362,10 @@ async def main():
spawn_service(
ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
Policy,
engine_params=EngineConfig(model=model),
sampling_overrides=SamplingOverrides(n=group_size, max_tokens=16),
engine_params=EngineArgOverrides(model=model),
sampling_overrides=SamplingOverrides(
n=group_size, max_tokens=max_res_tokens
),
),
spawn_service(
ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
Expand Down
4 changes: 3 additions & 1 deletion apps/vllm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@

async def run(cfg: DictConfig):

if "prompt" in cfg and cfg["prompt"] is not None:
if (prompt := cfg.get("prompt")) is None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like there's duplicate logic here

gd = cfg.policy.get("sampling_overrides", {}).get("guided_decoding", False)
prompt = "What is 3+5?" if gd else "Tell me a joke"
prompt = cfg["prompt"]
else:
gd = cfg.policy.get("sampling_overrides", {}).get("guided_decoding", False)
Expand Down
29 changes: 15 additions & 14 deletions src/forge/actors/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,9 @@ def from_dict(cls, d: Mapping):


@dataclass
class EngineConfig(EngineArgs):
class EngineArgOverrides(EngineArgs):
"""
EngineConfig extends EngineArgs with worker-specific fields.
EngineArgOverrides extends EngineArgs with worker-specific fields.
Overlapping keys in input dict will override EngineArgs defaults.
"""

Expand All @@ -103,8 +103,12 @@ def from_dict(cls, d: Mapping):

@dataclass
class Policy(PolicyInterface):
engine_params: EngineConfig | Mapping = field(default_factory=EngineConfig)
sampling_overrides: SamplingOverrides = field(default_factory=SamplingOverrides)
engine_params: EngineArgOverrides | Mapping = field(
default_factory=EngineArgOverrides
)
sampling_overrides: SamplingOverrides | Mapping = field(
default_factory=SamplingOverrides
)
available_devices: str | None = None
# Gets set up by setup
sampling_params: SamplingParams | None = None
Expand All @@ -119,7 +123,7 @@ def __post_init__(self):
self._worker_procs: ProcMesh | None = None
self.weights_version: int = 0
if isinstance(self.engine_params, Mapping):
self.engine_params = EngineConfig.from_dict(self.engine_params)
self.engine_params = EngineArgOverrides.from_dict(self.engine_params)
if isinstance(self.sampling_overrides, Mapping):
self.sampling_overrides = SamplingOverrides.from_dict(
self.sampling_overrides
Expand All @@ -130,7 +134,7 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride]
cls: type["Policy"],
*,
process_config: ProcessConfig,
engine_params: EngineConfig | Mapping = EngineConfig(),
engine_params: EngineArgOverrides | Mapping = EngineArgOverrides(),
sampling_overrides: SamplingOverrides | Mapping = SamplingOverrides(),
available_devices: str | None = None,
store: MultiProcessStore | None = None,
Expand All @@ -148,7 +152,7 @@ async def launch( # pyright: ignore[reportIncompatibleMethodOverride]
policy_proc = await get_proc_mesh(process_config=policy_proc_config)

if isinstance(engine_params, Mapping):
engine_params = EngineConfig.from_dict(engine_params)
engine_params = EngineArgOverrides.from_dict(engine_params)

if isinstance(engine_params, Mapping):
sampling_overrides = SamplingOverrides(**sampling_overrides)
Expand Down Expand Up @@ -368,7 +372,7 @@ async def stop(self):

@dataclass
class PolicyWorker(ForgeActor):
vllm_args: EngineConfig | dict = EngineConfig()
vllm_args: EngineArgOverrides | Mapping = EngineArgOverrides()
state_dict_key: str = "model_state_dict"

def __post_init__(self):
Expand All @@ -384,12 +388,9 @@ def __post_init__(self):
- all LLM generate methods, verify against LLM inputs
- all executor methods verify no changes
"""
if isinstance(self.vllm_args, dict):
self.vllm_args = EngineConfig.from_dict(self.vllm_args)
elif not isinstance(self.vllm_args, EngineConfig):
raise TypeError(
f"vllm_args must be a EngineConfig or dict, got {type(self.vllm_args)}"
)
if isinstance(self.vllm_args, Mapping):
self.vllm_args = EngineArgOverrides.from_dict(self.vllm_args)

# Original method returns False when not run in the main thread
self.vllm_args._is_v1_supported_oracle = lambda *_: True
# Build Config
Expand Down
4 changes: 2 additions & 2 deletions tests/integration_tests/test_policy_update.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import torch

from forge.actors.policy import EngineConfig, Policy, SamplingOverrides
from forge.actors.policy import EngineArgOverrides, Policy, SamplingOverrides
from forge.controller.service import ServiceConfig, spawn_service
from forge.data.sharding import VLLMSharding
from torchstore import MultiProcessStore
Expand Down Expand Up @@ -168,7 +168,7 @@ def validate_loaded_tensors_equals_original(

def get_configs(worker_size: int, model_name: str) -> Tuple[Dict, ServiceConfig]:

engine_params = EngineConfig(
engine_params = EngineArgOverrides(
model=model_name,
tensor_parallel_size=worker_size,
pipeline_parallel_size=1,
Expand Down
12 changes: 6 additions & 6 deletions tests/unit_tests/test_policy_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

import yaml

from forge.actors.policy import EngineConfig, Policy, SamplingOverrides
from forge.actors.policy import EngineArgOverrides, Policy, SamplingOverrides


class TestPolicyConfig(unittest.TestCase):
Expand All @@ -20,7 +20,7 @@ def test_policy_default_initialization(self):
policy = Policy()

# Default factories
self.assertIsInstance(policy.engine_params, EngineConfig)
self.assertIsInstance(policy.engine_params, EngineArgOverrides)
self.assertIsInstance(policy.sampling_overrides, SamplingOverrides)
self.assertIsNone(policy.available_devices)

Expand Down Expand Up @@ -62,7 +62,7 @@ def test_policy_with_dict_configs(self):
available_devices="test-gpu-device-abcd",
)

self.assertIsInstance(policy.engine_params, EngineConfig)
self.assertIsInstance(policy.engine_params, EngineArgOverrides)
self.assertIsInstance(policy.sampling_overrides, SamplingOverrides)

# Test basic fields
Expand Down Expand Up @@ -124,15 +124,15 @@ def test_policy_yaml_config_loading(self):

self.assertEqual(policy.available_devices, "yaml-test-device-xyz")

def test_engineconfig_ignores_invalid_keys(self):
"""EngineConfig.from_dict ignores unexpected keys."""
def test_engineargoverrides_ignores_invalid_keys(self):
"""EngineArgOverrides.from_dict ignores unexpected keys."""
engine_params = {
"model": "custom-model",
"tensor_parallel_size": 2,
"invalid_key_123": "should be ignored",
}

config = EngineConfig.from_dict(engine_params)
config = EngineArgOverrides.from_dict(engine_params)

self.assertEqual(config.model, "custom-model")
self.assertEqual(config.tensor_parallel_size, 2)
Expand Down
Loading