Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 8 additions & 13 deletions apps/grpo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
from forge.controller.actor import ForgeActor
from forge.controller.provisioner import shutdown
from forge.data.rewards import MathReward, ThinkingReward
from forge.data.utils import exclude_service
from forge.util.metric_logging import get_metric_logger
from monarch.actor import endpoint
from omegaconf import DictConfig
Expand Down Expand Up @@ -354,19 +353,15 @@ async def main(cfg: DictConfig):
ref_model,
reward_actor,
) = await asyncio.gather(
DatasetActor.options(**cfg.dataset.service).as_service(
**exclude_service(cfg.dataset)
DatasetActor.options(**cfg.services.dataset).as_service(**cfg.dataset),
Policy.options(**cfg.services.policy).as_service(**cfg.policy),
Trainer.options(**cfg.services.trainer).as_service(**cfg.trainer),
ReplayBuffer.options(**cfg.services.replay_buffer).as_service(
**cfg.replay_buffer
),
Policy.options(**cfg.policy.service).as_service(**exclude_service(cfg.policy)),
Trainer.options(**cfg.trainer.service).as_service(
**exclude_service(cfg.trainer)
),
ReplayBuffer.options(**cfg.replay_buffer.service).as_service(
**exclude_service(cfg.replay_buffer)
),
ComputeAdvantages.options(**cfg.compute_advantages.service).as_service(),
RefModel.options(**cfg.ref_model.service).as_service(model_name=model),
RewardActor.options(**cfg.reward_actor.service).as_service(
ComputeAdvantages.options(**cfg.services.compute_advantages).as_service(),
RefModel.options(**cfg.services.ref_model).as_service(**cfg.ref_model),
RewardActor.options(**cfg.services.reward_actor).as_service(
reward_functions=[MathReward(), ThinkingReward()]
),
)
Expand Down
47 changes: 22 additions & 25 deletions apps/grpo/qwen3_1_7b.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@ dataset:
data_split: "train"
streaming: true
model: ${model}
service:
procs_per_replica: 1
num_replicas: 1
with_gpus: false

# Policy configuration
policy:
Expand All @@ -31,47 +27,48 @@ policy:
max_tokens: ${max_res_tokens}
temperature: 1.0
top_p: 1.0
service:
procs_per_replica: 1
num_replicas: 1
with_gpus: true

# Trainer configuration
trainer:
model_name: ${model}
learning_rate: 1e-5
service:
procs_per_replica: 1
num_replicas: 1
with_gpus: true

# Reference model configuration
ref_model:
model_name: ${model}

# Replay buffer configuration
replay_buffer:
batch_size: ${batch_size}
max_policy_age: 1 # Async by 1
dp_size: 1
service:

services:
dataset:
procs_per_replica: 1
num_replicas: 1
with_gpus: false

# Compute advantages configuration
compute_advantages:
service:
policy:
procs_per_replica: 1
num_replicas: 1
with_gpus: true
trainer:
procs_per_replica: 1
num_replicas: 1
with_gpus: true
replay_buffer:
procs_per_replica: 1
num_replicas: 1
with_gpus: false

# Reference model configuration
ref_model:
service:
ref_model:
procs_per_replica: 1
num_replicas: 1
with_gpus: true

# Reward actor configuration
reward_actor:
service:
compute_advantages:
procs_per_replica: 1
num_replicas: 1
with_gpus: false
reward_actor:
procs_per_replica: 1
num_replicas: 1
with_gpus: false
51 changes: 24 additions & 27 deletions apps/grpo/qwen3_multinode.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,6 @@ dataset:
data_split: "train"
streaming: true
model: ${model}
service:
procs_per_replica: 1
num_replicas: 1
with_gpus: false

# Policy configuration
policy:
Expand All @@ -33,49 +29,50 @@ policy:
max_tokens: ${max_res_tokens}
temperature: 1.0
top_p: 1.0
service:
procs_per_replica: 1
hosts_per_replica: 1
num_replicas: 1
with_gpus: true

# Trainer configuration
trainer:
model_name: ${model}
learning_rate: 1e-5
service:
procs_per_replica: 1
hosts_per_replica: 1
num_replicas: 1
with_gpus: true

# Replay buffer configuration
replay_buffer:
batch_size: ${batch_size}
max_policy_age: 1 # Async by 1
dp_size: 1
service:

# Reference model configuration
ref_model:
model_name: ${model}

services:
dataset:
procs_per_replica: 1
num_replicas: 1
with_gpus: false

# Compute advantages configuration
compute_advantages:
service:
policy:
procs_per_replica: 1
hosts_per_replica: 1
num_replicas: 1
with_gpus: true
trainer:
procs_per_replica: 1
hosts_per_replica: 1
num_replicas: 1
with_gpus: true
replay_buffer:
procs_per_replica: 1
num_replicas: 1
with_gpus: false

# Reference model configuration
ref_model:
service:
compute_advantages:
procs_per_replica: 1
num_replicas: 1
with_gpus: false
ref_model:
procs_per_replica: 1
num_replicas: 1
with_gpus: true

# Reward actor configuration
reward_actor:
service:
reward_actor:
procs_per_replica: 1
num_replicas: 1
with_gpus: false
5 changes: 3 additions & 2 deletions apps/vllm/deepseek_r1.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@ policy:
n: 2
guided_decoding: false
max_tokens: 512
available_devices: null
service:

services:
policy:
procs_per_replica: 8
hosts_per_replica: 2
num_replicas: 1
Expand Down
4 changes: 3 additions & 1 deletion apps/vllm/llama3_8b.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@ policy:
n: 2
guided_decoding: false
max_tokens: 512
service:

services:
policy:
procs_per_replica: 2
num_replicas: 1
with_gpus: true
Expand Down
5 changes: 1 addition & 4 deletions apps/vllm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from forge.controller.provisioner import shutdown

from omegaconf import DictConfig
from src.forge.data.utils import exclude_service
from vllm.outputs import RequestOutput

os.environ["HYPERACTOR_MESSAGE_DELIVERY_TIMEOUT_SECS"] = "600"
Expand All @@ -32,9 +31,7 @@ async def run(cfg: DictConfig):
prompt = "What is 3+5?" if gd else "Tell me a joke"

print("Spawning service...")
policy = await Policy.options(**cfg.policy.service).as_service(
**exclude_service(cfg.policy)
)
policy = await Policy.options(**cfg.services.policy).as_service(**cfg.policy)

try:
async with policy.session():
Expand Down
5 changes: 3 additions & 2 deletions apps/vllm/qwen2_5_32b.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@ policy:
n: 2
guided_decoding: false
max_tokens: 512
available_devices: null
service:

services:
policy:
procs_per_replica: 4
hosts_per_replica: 1
num_replicas: 1
Expand Down
2 changes: 1 addition & 1 deletion src/forge/controller/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ async def as_service(cls: Type[T], **actor_kwargs) -> "ServiceInterface":
# dynamically create a configured subclass for consistency
cls = type(f"{cls.__name__}Configured", (cls,), {"_service_config": cfg})

logger.info(("Spawning Service Actor for %s", cls.__name__))
logger.info("Spawning Service Actor for %s", cls.__name__)
service = Service(cfg, cls, actor_kwargs)
await service.__initialize__()
return ServiceInterface(service, cls)
Expand Down
7 changes: 0 additions & 7 deletions src/forge/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,10 +214,3 @@ def batch_to_device(batch: dict, device: torch.device) -> None:
f"Tensor, or BlockMask with flexattention enabled. "
f'Got key "{k}" with value of type {type(v)}'
)


def exclude_service(config_dict: dict) -> dict:
    """Return a copy of ``config_dict`` with any 'service' entry dropped.

    The original mapping is left unmodified; a missing 'service' key is
    not an error.
    """
    trimmed = config_dict.copy()
    if "service" in trimmed:
        del trimmed["service"]
    return trimmed
Loading