Merged
4 changes: 2 additions & 2 deletions apps/grpo/qwen3_1_7b.yaml
@@ -32,7 +32,7 @@ dataset:

 # Policy configuration
 policy:
-  engine_config:
+  engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs
     model: ${model}
     tensor_parallel_size: 1
     pipeline_parallel_size: 1
@@ -115,7 +115,7 @@ ref_model:
 # All resource allocations
 services:
   policy:
-    procs: ${policy.engine_config.tensor_parallel_size}
+    procs: ${policy.engine_args.tensor_parallel_size}
     num_replicas: 1
     with_gpus: true
   ref_model:
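Note on the second hunk in each of these configs: `services.policy.procs` refers back into the policy block via `${policy.engine_args.tensor_parallel_size}`, so the rename has to land on both the key and every interpolation that points at it, or the reference would fail to resolve. A minimal sketch of how that resolution works, assuming these YAMLs are loaded with OmegaConf (the `${...}` syntax suggests it; the model name below is illustrative):

from omegaconf import OmegaConf

# Hedged sketch: the ${...} references resolve only if they name the
# renamed engine_args key, which is why procs is updated in lockstep.
cfg = OmegaConf.create(
    """
model: Qwen/Qwen3-1.7B
policy:
  engine_args:
    model: ${model}
    tensor_parallel_size: 1
services:
  policy:
    procs: ${policy.engine_args.tensor_parallel_size}
"""
)
assert cfg.policy.engine_args.model == "Qwen/Qwen3-1.7B"
assert cfg.services.policy.procs == 1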
4 changes: 2 additions & 2 deletions apps/grpo/qwen3_32b.yaml
@@ -35,7 +35,7 @@ dataset:

 # Policy configuration
 policy:
-  engine_config:
+  engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs
     model: ${model}
     tensor_parallel_size: 4
     pipeline_parallel_size: 1
@@ -118,7 +118,7 @@ ref_model:
 # All resource allocations
 services:
   policy:
-    procs: ${policy.engine_config.tensor_parallel_size}
+    procs: ${policy.engine_args.tensor_parallel_size}
     num_replicas: 1
     hosts: 1
     with_gpus: true
4 changes: 2 additions & 2 deletions apps/grpo/qwen3_8b.yaml
@@ -28,7 +28,7 @@ dataset:

 # Policy configuration
 policy:
-  engine_config:
+  engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs
     model: ${model}
     tensor_parallel_size: 2
     pipeline_parallel_size: 1
@@ -114,7 +114,7 @@ ref_model:
 # All resource allocations
 services:
   policy:
-    procs: ${policy.engine_config.tensor_parallel_size}
+    procs: ${policy.engine_args.tensor_parallel_size}
     num_replicas: 1
     with_gpus: true
   ref_model:
4 changes: 2 additions & 2 deletions apps/mast/qwen3_14b_mast.yaml
@@ -34,7 +34,7 @@ dataset:

 # Policy configuration
 policy:
-  engine_config:
+  engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs
     model: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-14B/snapshots/8268fe3026cb304910457689366670e803a6fd56
     tensor_parallel_size: 2
     pipeline_parallel_size: 1
@@ -129,7 +129,7 @@ ref_model:
 # All resource allocations
 services:
   policy:
-    procs: ${policy.engine_config.tensor_parallel_size}
+    procs: ${policy.engine_args.tensor_parallel_size}
     num_replicas: 2
     with_gpus: true
     mesh_name: policy
4 changes: 2 additions & 2 deletions apps/mast/qwen3_1_7b_mast.yaml
@@ -34,7 +34,7 @@ dataset:

 # Policy configuration
 policy:
-  engine_config:
+  engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs
     model: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-1.7B/snapshots/0060bc56d46589041c1048efd1a397421b1142b5
     tensor_parallel_size: 1
     pipeline_parallel_size: 1
@@ -125,7 +125,7 @@ ref_model:
 # All resource allocations
 services:
   policy:
-    procs: ${policy.engine_config.tensor_parallel_size}
+    procs: ${policy.engine_args.tensor_parallel_size}
     num_replicas: 2
     with_gpus: true
     mesh_name: policy
4 changes: 2 additions & 2 deletions apps/mast/qwen3_32b_mast.yaml
@@ -34,7 +34,7 @@ dataset:

 # Policy configuration
 policy:
-  engine_config:
+  engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs
     model: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-32B/snapshots/d47b0d4ae4b48fde975756bf360a63a9cca8d470
     tensor_parallel_size: 2
     pipeline_parallel_size: 1
@@ -128,7 +128,7 @@ ref_model:
 # All resource allocations
 services:
   policy:
-    procs: ${policy.engine_config.tensor_parallel_size}
+    procs: ${policy.engine_args.tensor_parallel_size}
     num_replicas: 2
     with_gpus: true
     mesh_name: policy
4 changes: 2 additions & 2 deletions apps/mast/qwen3_4b_mast.yaml
@@ -34,7 +34,7 @@ dataset:

 # Policy configuration
 policy:
-  engine_config:
+  engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs
     model: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-4B-Base/snapshots/a81b894c2624d21c88a3ad737ce4f837424b7eed
     tensor_parallel_size: 2
     pipeline_parallel_size: 1
@@ -125,7 +125,7 @@ ref_model:
 # All resource allocations
 services:
   policy:
-    procs: ${policy.engine_config.tensor_parallel_size}
+    procs: ${policy.engine_args.tensor_parallel_size}
     num_replicas: 2
     with_gpus: true
     mesh_name: policy
4 changes: 2 additions & 2 deletions apps/mast/qwen3_8b_mast.yaml
@@ -34,7 +34,7 @@ dataset:

 # Policy configuration
 policy:
-  engine_config:
+  engine_args: # https://docs.vllm.ai/en/v0.10.0/api/vllm/engine/arg_utils.html#vllm.engine.arg_utils.EngineArgs
     model: /mnt/wsfuse/huggingface_models/models--Qwen--Qwen3-8B/snapshots/model
     tensor_parallel_size: 2
     pipeline_parallel_size: 1
@@ -125,7 +125,7 @@ ref_model:
 # All resource allocations
 services:
   policy:
-    procs: ${policy.engine_config.tensor_parallel_size}
+    procs: ${policy.engine_args.tensor_parallel_size}
     num_replicas: 2
     with_gpus: true
     mesh_name: policy
55 changes: 15 additions & 40 deletions src/forge/actors/policy.py
@@ -12,7 +12,7 @@
 import sys
 from collections.abc import Mapping
 from copy import copy
-from dataclasses import dataclass, field, fields
+from dataclasses import dataclass, field

 import torch
 import torch.distributed.checkpoint as dcp
@@ -62,39 +62,9 @@
 logger.setLevel(logging.INFO)


-@dataclass
-class EngineConfig(EngineArgs):
-    """
-    EngineConfig extends EngineArgs with worker-specific fields.
-    Overlapping keys in input dict will override EngineArgs defaults.
-    """
-
-    model: str = "meta-llama/Llama-3.1-8B-Instruct"
-    tensor_parallel_size: int = 1
-    pipeline_parallel_size: int = 1
-    enforce_eager: bool = False
-    enable_expert_parallel: bool = False
-
-    # Original method returns False when not run in the main thread
-    _is_v1_supported_oracle = lambda *_: True
-
-    @classmethod
-    def from_dict(cls, d: Mapping):
-        d = dict(d)
-        all_fields = [f.name for f in fields(cls)]
-        valid_args = {k: v for k, v in d.items() if k in all_fields}
-        return cls(**valid_args)
-
-    def create_vllm_config(self) -> VllmConfig:
-        """Converts the current EngineConfig into vLLM's VllmConfig."""
-        # Note: EngineArgs.create_engine_config
-        # creates a VllmConfig
-        return self.create_engine_config(UsageContext.LLM_CLASS)
-
-
 @dataclass
 class Policy(PolicyInterface):
-    engine_config: EngineConfig | Mapping = field(default_factory=EngineConfig)
+    engine_args: EngineArgs | Mapping = field(default_factory=EngineArgs)
     sampling_params: SamplingParams | Mapping = field(default_factory=SamplingParams)
     available_devices: str | None = None
     use_dcp: bool = True
@@ -111,8 +81,9 @@ def __post_init__(self):
         self._worker_procs: ProcMesh | None = None
         self.running = False

-        if isinstance(self.engine_config, Mapping):
-            self.engine_config = EngineConfig.from_dict(self.engine_config)
+        if isinstance(self.engine_args, Mapping):
+            self.engine_args = EngineArgs(**self.engine_args)
+        self.engine_args._is_v1_supported_oracle = lambda *_: True

         if isinstance(self.sampling_params, Mapping):
             self.sampling_params = SamplingParams.from_optional(**self.sampling_params)
@@ -122,7 +93,7 @@ def __post_init__(self):
     async def launch(  # pyright: ignore[reportIncompatibleMethodOverride]
         cls: type["Policy"],
         *,
-        engine_config: EngineConfig | Mapping = EngineConfig(),
+        engine_args: EngineArgs | Mapping = EngineArgs(),
         sampling_params: SamplingParams | Mapping = SamplingParams(),
         available_devices: str | None = None,
         use_dcp: bool = True,
@@ -150,10 +121,12 @@ async def launch(  # pyright: ignore[reportIncompatibleMethodOverride]
         policy_proc_config.with_gpus = False
         policy_proc = await get_proc_mesh(process_config=policy_proc_config)

-        if isinstance(engine_config, Mapping):
-            engine_config = EngineConfig.from_dict(engine_config)
+        if isinstance(engine_args, Mapping):
+            engine_args = EngineArgs(**engine_args)
+        engine_args._is_v1_supported_oracle = lambda *_: True  # Always default on
+        logger.debug(f"Resolved engine args: {engine_args}")

-        vllm_config = engine_config.create_vllm_config()
+        vllm_config = engine_args.create_engine_config(UsageContext.LLM_CLASS)
         workers = worker_procs.spawn(
             "vllm_worker", PolicyWorker, vllm_config=vllm_config, use_dcp=use_dcp
         )

[Review comment from a Contributor, on `engine_args = EngineArgs(**engine_args)`]
Same flag as the other PR: when folks want to use the richer nested fields in EngineArgs, we'll need to support it. (A hedged sketch of one option follows this file's diff.)

@@ -168,7 +141,7 @@ async def launch(  # pyright: ignore[reportIncompatibleMethodOverride]
         policy = policy_proc.spawn(
             actor_name,
             cls,
-            engine_config=engine_config,
+            engine_args=engine_args,
             sampling_params=sampling_params,
             available_devices=available_devices,
             policy_worker=workers,
@@ -214,7 +187,9 @@ async def setup(self):
         # Guard for updating requests
         self.update_lock = asyncio.Condition()

-        self.vllm_config: VllmConfig = self.engine_config.create_vllm_config()
+        self.vllm_config: VllmConfig = self.engine_args.create_engine_config(
+            UsageContext.LLM_CLASS
+        )

         # Setup processors
         # TODO: move all processing to the Environment
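The net effect of this file's changes: the plain mapping from the YAML `engine_args:` block is passed straight to vLLM's `EngineArgs`, the v1-support oracle is forced on (the original method returns False when not called from the main thread), and `create_engine_config(UsageContext.LLM_CLASS)` yields the `VllmConfig` handed to the workers. A minimal sketch of that path under those assumptions (model name illustrative; this is not the actor code itself):

from vllm.engine.arg_utils import EngineArgs
from vllm.usage.usage_lib import UsageContext

yaml_engine_args = {  # what the config loader hands over from `engine_args:`
    "model": "Qwen/Qwen3-1.7B",  # illustrative
    "tensor_parallel_size": 1,
    "pipeline_parallel_size": 1,
}

engine_args = EngineArgs(**yaml_engine_args)
# Same workaround the PR applies: the original oracle returns False
# when not run in the main thread.
engine_args._is_v1_supported_oracle = lambda *_: True
vllm_config = engine_args.create_engine_config(UsageContext.LLM_CLASS)

One behavioral difference worth noting: the removed EngineConfig.from_dict silently dropped unknown keys, whereas EngineArgs(**d) raises TypeError on any key that is not an EngineArgs field, so stray YAML keys now fail fast. If the richer nested fields the reviewer mentions ever need the old tolerance back, a hypothetical helper (named here for illustration only) could reinstate the filtering:

from dataclasses import fields

def engine_args_from_dict(d: dict) -> EngineArgs:
    # Hypothetical: keep only real EngineArgs dataclass fields, mirroring
    # the removed EngineConfig.from_dict, instead of raising TypeError.
    valid = {f.name for f in fields(EngineArgs)}
    return EngineArgs(**{k: v for k, v in d.items() if k in valid})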
4 changes: 2 additions & 2 deletions tests/integration_tests/fixtures/qwen3_1_7b_no_tp.yaml
@@ -9,7 +9,7 @@ off_by_n: 1 # Off by one by default

 # Policy configuration
 policy:
-  engine_config:
+  engine_args:
     model: ${model}
     tensor_parallel_size: 1
     pipeline_parallel_size: 1
@@ -63,7 +63,7 @@ trainer:
 # All resource allocations
 services:
   policy:
-    procs: ${policy.engine_config.tensor_parallel_size}
+    procs: ${policy.engine_args.tensor_parallel_size}
     num_replicas: 1
     with_gpus: true
   trainer:
4 changes: 2 additions & 2 deletions tests/integration_tests/fixtures/qwen3_1_7b_tp.yaml
@@ -11,7 +11,7 @@ off_by_n: 1 # Off by one by default

 # Policy configuration
 policy:
-  engine_config:
+  engine_args:
     model: ${model}
     tensor_parallel_size: 4
     pipeline_parallel_size: 1
@@ -65,7 +65,7 @@ trainer:
 # All resource allocations
 services:
   policy:
-    procs: ${policy.engine_config.tensor_parallel_size}
+    procs: ${policy.engine_args.tensor_parallel_size}
     num_replicas: 1
     with_gpus: true
   trainer:
2 changes: 1 addition & 1 deletion tests/integration_tests/test_policy_update.py
@@ -166,7 +166,7 @@ async def test_sanity_check(self, request):
         cfg = self._load_config(config_path=config_path)

         trainer_proc_size = cfg.actors.trainer.procs
-        policy_tp_size = cfg.policy.engine_config.tensor_parallel_size
+        policy_tp_size = cfg.policy.engine_args.tensor_parallel_size

         if policy_tp_size != cfg.services.policy.procs:
             pytest.fail(
4 changes: 2 additions & 2 deletions tests/integration_tests/test_vllm_policy_correctness.py
@@ -53,7 +53,7 @@ async def test_same_output():
     policy = await Policy.options(
         procs=1, num_replicas=1, with_gpus=True
     ).as_service(
-        engine_config={
+        engine_args={
             "model": MODEL_NAME,
             "tensor_parallel_size": TENSOR_PARALLEL_SIZE,
             "enforce_eager": ENFORCE_EAGER,
@@ -143,7 +143,7 @@ async def test_cache_usage():
     policy = await Policy.options(
         procs=1, num_replicas=1, with_gpus=True
     ).as_service(
-        engine_config={
+        engine_args={
             "model": MODEL_NAME,
             "tensor_parallel_size": TENSOR_PARALLEL_SIZE,
             "enforce_eager": ENFORCE_EAGER,
2 changes: 1 addition & 1 deletion tests/sandbox/toy_rl/sumdigits-tp.yaml
@@ -13,7 +13,7 @@ dataset:

 # Policy configuration
 policy:
-  engine_config:
+  engine_args:
     model: ${model}
     tensor_parallel_size: 2
     pipeline_parallel_size: 1
2 changes: 0 additions & 2 deletions tests/sandbox/toy_rl/sumdigits.py
@@ -427,8 +427,6 @@ async def main(cfg: DictConfig):
     group_size = cfg.group_size
     max_req_tokens = cfg.max_req_tokens
     max_res_tokens = cfg.max_res_tokens
-    # TODO: delete this logic after we are confident on the vllm weight sync long term fix PR #184
-    policy_tp_size = cfg.policy.engine_config.tensor_parallel_size

     # ---- Setup services ---- #
     print(f"{cfg.policy=}")
2 changes: 1 addition & 1 deletion tests/sandbox/toy_rl/sumdigits.yaml
@@ -14,7 +14,7 @@ dataset:
 # Policy configuration
 policy:
   use_dcp: false
-  engine_config:
+  engine_args:
     model: ${model}
     tensor_parallel_size: 1
     pipeline_parallel_size: 1
2 changes: 1 addition & 1 deletion tests/sandbox/vllm/deepseek_r1.yaml
@@ -2,7 +2,7 @@

 # NOTE - this won't work until we have proper HostMesh support
 policy:
-  engine_config:
+  engine_args:
     model: "deepseek-ai/DeepSeek-R1-0528"
     tensor_parallel_size: 16
     pipeline_parallel_size: 1
4 changes: 2 additions & 2 deletions tests/sandbox/vllm/llama3_8b.yaml
@@ -1,7 +1,7 @@
 # >>> python -m tests.sandbox.vllm.main --config tests/sandbox/vllm/llama3_8b.yaml

 policy:
-  engine_config:
+  engine_args:
     model: "meta-llama/Llama-3.1-8B-Instruct"
     tensor_parallel_size: 2
     pipeline_parallel_size: 1
@@ -12,7 +12,7 @@ policy:

 services:
   policy:
-    procs: ${policy.engine_config.tensor_parallel_size}
+    procs: ${policy.engine_args.tensor_parallel_size}
     num_replicas: 4
     with_gpus: true
2 changes: 1 addition & 1 deletion tests/sandbox/vllm/qwen2_5_32b.yaml
@@ -1,7 +1,7 @@
 # >>> python -m tests.sandbox.vllm.main --config tests/sandbox/vllm/qwen2_5_32b.yaml

 policy:
-  engine_config:
+  engine_args:
     model: "Qwen/Qwen2.5-32B"
     tensor_parallel_size: 4
     pipeline_parallel_size: 1