Skip to content

Commit 579697b

Browse files
committed
merge conflicts
2 parents 021727c + b77d6ad commit 579697b

File tree

15 files changed

+160
-50
lines changed

15 files changed

+160
-50
lines changed

.github/workflows/docs.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ jobs:
4040
run: python -m pip install torch==2.9.0 --index-url https://download.pytorch.org/whl/test/cu130
4141
- name: Install monarch
4242
shell: bash -l {0}
43-
run: python -m pip install monarch-no-torch==0.1.0.dev20250826 --find-links assets/ci
43+
run: pip install assets/ci/monarch_no_torch-0.1.0.dev20251010-py3-none-any.whl
4444
- name: Install torchforge
4545
shell: bash -l {0}
4646
env:

.github/workflows/unit_test.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ jobs:
2626
- name: Install pytorch
2727
run: python -m pip install torch==2.9.0.dev20250826 --extra-index-url https://download.pytorch.org/whl/nightly/cpu
2828
- name: Install monarch
29-
run: python -m pip install monarch-no-torch==0.1.0.dev20250826 --find-links assets/ci
29+
run: pip install assets/ci/monarch_no_torch-0.1.0.dev20251010-py3-none-any.whl
3030
- name: Install torchstore
3131
run: pip install assets/wheels/torchstore-0.1.0-py3-none-any.whl
3232
- name: Install torchtitan

apps/grpo/main.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from forge.controller.actor import ForgeActor
2929
from forge.controller.provisioner import init_provisioner, shutdown
3030
from forge.data.rewards import MathReward, ThinkingReward
31+
from forge.env import MONARCH_HOSTMESH_V1
3132
from forge.observability.metric_actors import get_or_create_metric_logger
3233
from forge.observability.metrics import record_metric, Reduce
3334
from forge.observability.perf_tracker import Tracer
@@ -314,14 +315,23 @@ async def main(cfg: DictConfig):
314315
max_res_tokens = cfg.max_res_tokens
315316

316317
# ---- Global setups ---- #
318+
provisioner = None
317319
if cfg.get("provisioner", None) is not None:
318-
await init_provisioner(
320+
provisioner = await init_provisioner(
319321
ProvisionerConfig(launcher_config=LauncherConfig(**cfg.provisioner))
320322
)
323+
else:
324+
provisioner = await init_provisioner()
325+
321326
metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
322327
mlogger = await get_or_create_metric_logger()
323328
await mlogger.init_backends.call_one(metric_logging_cfg)
324-
await ts.initialize(strategy=ts.ControllerStorageVolumes())
329+
330+
# In the host mesh v0 case, actors on remote hosts are not able to communicate
331+
# with one another. Therefore we use the controller as our storage volume.
332+
if not MONARCH_HOSTMESH_V1.get_value():
333+
await ts.initialize(strategy=ts.ControllerStorageVolumes())
334+
print("Torchstore successfully initialized with controller storage strategy")
325335

326336
# ---- Setup services ---- #
327337

@@ -351,6 +361,22 @@ async def main(cfg: DictConfig):
351361

352362
print("All services initialized successfully!")
353363

364+
# In the HostMesh v1 case, we spawn a torchstore storage volume
365+
# per trainer process.
366+
# We initialize after service initialization because torchstore currently
367+
# requires access to the underlying proc meshes in the local rank strategy.
368+
# We should be able to hide this in the future.
369+
if MONARCH_HOSTMESH_V1.get_value():
370+
# TODO: support multiple host meshes
371+
trainer_num_procs = cfg.actors.trainer["procs"]
372+
trainer_host_mesh_name = cfg.actors.trainer["mesh_name"]
373+
trainer_hosts = provisioner.get_host_mesh(trainer_host_mesh_name)
374+
await ts.initialize(
375+
mesh=trainer_hosts.spawn_procs(per_host={"procs": trainer_num_procs}),
376+
strategy=ts.LocalRankStrategy(),
377+
)
378+
print("Torchstore successfully initialized with local rank strategy")
379+
354380
# ---- Core RL loops ---- #
355381
async def continuous_rollouts():
356382
rollout_count = 0

apps/grpo/qwen3_1_7b.yaml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,26 +117,33 @@ services:
117117
policy:
118118
procs: ${policy.engine_args.tensor_parallel_size}
119119
num_replicas: 1
120+
mesh_name: policy
120121
with_gpus: true
121122
ref_model:
122123
procs: 1
123124
num_replicas: 1
125+
mesh_name: ref_model
124126
with_gpus: true
125127
reward_actor:
126128
procs: 1
127129
num_replicas: 1
130+
mesh_name: reward_actor
128131
with_gpus: false
129132

130133
actors:
131134
dataset:
132135
procs: 1
133136
with_gpus: false
137+
mesh_name: dataset
134138
trainer:
135139
procs: 1
136140
with_gpus: true
141+
mesh_name: trainer
137142
replay_buffer:
138143
procs: 1
139144
with_gpus: false
145+
mesh_name: replay_buffer
140146
compute_advantages:
141147
procs: 1
142148
with_gpus: false
149+
mesh_name: compute_advantages

apps/grpo/qwen3_32b.yaml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,26 +122,33 @@ services:
122122
num_replicas: 1
123123
hosts: 1
124124
with_gpus: true
125+
mesh_name: policy
125126
ref_model:
126127
procs: ${ref_model.parallelism.tensor_parallel_degree}
127128
num_replicas: 1
128129
with_gpus: true
130+
mesh_name: ref_model
129131
reward_actor:
130132
procs: 1
131133
num_replicas: 1
132134
with_gpus: false
135+
mesh_name: reward_actor
133136

134137
actors:
135138
dataset:
136139
procs: 1
137140
with_gpus: false
141+
mesh_name: dataset
138142
trainer:
139143
procs: 8
140144
hosts: 1
141145
with_gpus: true
146+
mesh_name: trainer
142147
replay_buffer:
143148
procs: 1
144149
with_gpus: false
150+
mesh_name: replay_buffer
145151
compute_advantages:
146152
procs: 1
147153
with_gpus: false
154+
mesh_name: compute_advantages

apps/grpo/qwen3_8b.yaml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,25 +117,32 @@ services:
117117
procs: ${policy.engine_args.tensor_parallel_size}
118118
num_replicas: 1
119119
with_gpus: true
120+
mesh_name: policy
120121
ref_model:
121122
procs: 1
122123
num_replicas: 1
123124
with_gpus: true
125+
mesh_name: ref_model
124126
reward_actor:
125127
procs: 1
126128
num_replicas: 1
127129
with_gpus: false
130+
mesh_name: reward_actor
128131

129132
actors:
130133
dataset:
131134
procs: 1
132135
with_gpus: false
136+
mesh_name: dataset
133137
trainer:
134138
procs: 2
135139
with_gpus: true
140+
mesh_name: trainer
136141
replay_buffer:
137142
procs: 1
138143
with_gpus: false
144+
mesh_name: replay_buffer
139145
compute_advantages:
140146
procs: 1
141147
with_gpus: false
148+
mesh_name: compute_advantages
758 KB
Binary file not shown.

src/forge/actors/policy.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
from forge.data.sharding import VLLMSharding
5454
from forge.data_models.completion import Completion
5555
from forge.data_models.prompt import to_prompt
56+
from forge.env import TORCHSTORE_USE_RDMA
5657
from forge.interfaces import Policy as PolicyInterface
5758
from forge.observability.metrics import record_metric, Reduce
5859
from forge.observability.perf_tracker import Tracer
@@ -67,7 +68,9 @@ class Policy(PolicyInterface):
6768
engine_args: EngineArgs | Mapping = field(default_factory=EngineArgs)
6869
sampling_params: SamplingParams | Mapping = field(default_factory=SamplingParams)
6970
available_devices: str | None = None
70-
use_dcp: bool = True
71+
use_dcp: bool = (
72+
TORCHSTORE_USE_RDMA.get_value() == 0
73+
) # torchstore currently only accepts 0 or 1
7174
# Gets set up by setup
7275
lora_request: LoRARequest | None = None
7376
tokenization_kwargs: dict = field(default_factory=dict)
@@ -83,7 +86,7 @@ def __post_init__(self):
8386

8487
if isinstance(self.engine_args, Mapping):
8588
self.engine_args = EngineArgs(**self.engine_args)
86-
self.engine_args._is_v1_supported_oracle = lambda *_: True
89+
self.engine_args._is_v1_supported_oracle = lambda *_: True
8790

8891
if isinstance(self.sampling_params, Mapping):
8992
self.sampling_params = SamplingParams.from_optional(**self.sampling_params)

src/forge/actors/trainer.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646

4747
from forge.controller import ForgeActor
4848
from forge.data.utils import batch_to_device
49+
from forge.env import TORCHSTORE_USE_RDMA
4950
from forge.observability.metrics import record_metric, Reduce
5051
from forge.observability.perf_tracker import Tracer
5152

@@ -111,7 +112,9 @@ class RLTrainer(ForgeActor):
111112
# Non JobConfig-related fields
112113
loss: Callable = lambda logits, **targets: logits
113114
state_dict_key: str = "model_state_dict"
114-
use_dcp: bool = True
115+
use_dcp: bool = (
116+
TORCHSTORE_USE_RDMA.get_value() == 0
117+
) # torchstore currently only accepts 0 or 1
115118
dcp_path: str = "forge_dcp_tmp"
116119

117120
def __post_init__(self):

0 commit comments

Comments
 (0)