Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion apps/grpo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,8 @@ async def continuous_training():
t.step("update_weights")

if training_step >= 2:
await drop_weights(training_step - 1)
# TODO: figure out why using training_step - 1 here triggers an error
await drop_weights(training_step - 2)
t.step("drop_weights")

t.stop()
Expand Down
1 change: 1 addition & 0 deletions apps/grpo/qwen3_1_7b.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ policy:

# Trainer configuration
trainer:
use_dcp: false
model:
name: qwen3
flavor: 1.7B
Expand Down
17 changes: 9 additions & 8 deletions apps/grpo/qwen3_32b.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@
# NOTE - This has not been tested for correctness yet! All testing so far has been only for infrastructure stability

# Global configuration
group_size: 2
local_batch_size: 8 # per-device batch size
group_size: 16
local_batch_size: 16
max_req_tokens: 512
max_res_tokens: 512
max_res_tokens: 1536
model: "Qwen/Qwen3-32B"
off_by_n: 1 # Off by one by default

provisioner:
launcher: slurm

# Main loop configuration
rollout_threads: 1 # Recommended to set equal to policy.num_replicas
rollout_threads: 16 # equal to batch size for now

# Observability configuration
metric_logging:
Expand Down Expand Up @@ -48,6 +48,7 @@ policy:

# Trainer configuration
trainer:
use_dcp: false
model:
name: qwen3
flavor: 32B
Expand All @@ -69,8 +70,8 @@ trainer:
enable: false
parallelism:
data_parallel_replicate_degree: 1
data_parallel_shard_degree: -1
tensor_parallel_degree: 1
data_parallel_shard_degree: 1
tensor_parallel_degree: 8
pipeline_parallel_degree: 1
context_parallel_degree: 1
expert_parallel_degree: 1
Expand All @@ -90,7 +91,7 @@ replay_buffer:
batch_size: ${local_batch_size}
max_policy_age: ${off_by_n}
# dp_size: ${trainer.parallelism.data_parallel_shard_degree} # Must equal trainer DP degree
dp_size: 8
dp_size: 1

# Reference model configuration
ref_model:
Expand Down Expand Up @@ -119,7 +120,7 @@ ref_model:
services:
policy:
procs: ${policy.engine_args.tensor_parallel_size}
num_replicas: 1
num_replicas: 4
hosts: 1
with_gpus: true
ref_model:
Expand Down
4 changes: 2 additions & 2 deletions apps/grpo/qwen3_8b.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ policy:

# Trainer configuration
trainer:
use_dcp: true
use_dcp: false
model:
name: qwen3
flavor: 8B
Expand All @@ -53,7 +53,7 @@ trainer:
lr_scheduler:
warmup_steps: 1
training:
local_local_batch_size: ${local_batch_size}
local_batch_size: ${local_batch_size}
seq_len: 2048
max_norm: 1.0
steps: 1000000
Expand Down
8 changes: 8 additions & 0 deletions src/forge/actors/policy.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import asyncio
import logging
import os
import socket
import sys
from collections.abc import Mapping
from copy import copy
Expand Down Expand Up @@ -61,6 +62,13 @@
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Hostname of the machine running this process, looked up once at import time.
hostname = socket.gethostname()


async def _ts_parallel_get(keys: list[str]) -> list[torch.Tensor]:
    """Issue one ``ts.get`` per key and await them all concurrently.

    Args:
        keys: Keys to look up.

    Returns:
        The fetched values, in the same order as ``keys``.
    """
    pending = (ts.get(name) for name in keys)
    return await asyncio.gather(*pending)


@dataclass
class Policy(PolicyInterface):
Expand Down
30 changes: 16 additions & 14 deletions src/forge/actors/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import asyncio
import logging
import math
import os
Expand All @@ -12,12 +13,23 @@
import time
from collections.abc import Mapping
from dataclasses import dataclass, field, fields
from typing import Callable
from typing import Callable, Iterable

import torch
import torch.distributed.checkpoint as dcp
import torchstore as ts

from forge.actors._torchstore_utils import (
DcpHandle,
get_dcp_whole_state_dict_key,
get_param_key,
)

from forge.controller import ForgeActor
from forge.data.utils import batch_to_device
from forge.observability.metrics import record_metric, Reduce
from forge.observability.perf_tracker import Tracer

from monarch.actor import current_rank, current_size, endpoint
from torch import Tensor
from torch.distributed.checkpoint._nested_dict import flatten_state_dict
Expand All @@ -38,17 +50,6 @@
from torchtitan.experiments.forge.engine import ForgeEngine
from torchtitan.experiments.forge.job_config import ForgeJobConfig

from forge.actors._torchstore_utils import (
DcpHandle,
get_dcp_whole_state_dict_key,
get_param_key,
)

from forge.controller import ForgeActor
from forge.data.utils import batch_to_device
from forge.observability.metrics import record_metric, Reduce
from forge.observability.perf_tracker import Tracer

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

Expand Down Expand Up @@ -111,7 +112,7 @@ class RLTrainer(ForgeActor):
# Non JobConfig-related fields
loss: Callable = lambda logits, **targets: logits
state_dict_key: str = "model_state_dict"
use_dcp: bool = True
use_dcp: bool = False
dcp_path: str = "forge_dcp_tmp"

def __post_init__(self):
Expand Down Expand Up @@ -336,7 +337,8 @@ async def push_weights(self, policy_version: int) -> None:
else:
for name, param in hf_state_dict.items():
key = get_param_key(policy_version, name)
await ts.put(key, param)
await ts.put(key, param.detach().cpu())
t.step("ts_save")
t.step("ts_save")
t.stop()
end_time = time.perf_counter()
Expand Down
6 changes: 6 additions & 0 deletions src/forge/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,12 @@ def get_value(self) -> Any:
description="Sets the maximum frame length for Monarch's actor message delivery in bytes.",
)

# Registers the OMP_NUM_THREADS environment variable with a default of 16.
OMP_NUM_THREADS = EnvVar(
name="OMP_NUM_THREADS",
default=16,  # Recommended: <= (# CPU cores / # GPUs), since each process uses 1 GPU
description="Sets the number of OpenMP threads to use. This is used for CPU-bound operations in PyTorch.",
)


def all_env_vars() -> list[EnvVar]:
"""Retrieves all registered environment variable names."""
Expand Down