Skip to content

Commit f79beee

Browse files
authored
Merge branch 'meta-pytorch:main' into main
2 parents fd1d38b + ddd0794 commit f79beee

File tree

6 files changed

+46
-71
lines changed

6 files changed

+46
-71
lines changed

.github/workflows/unit_test.yaml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,12 @@ jobs:
2323
python-version: ${{ matrix.python-version }}
2424
- name: Update pip
2525
run: python -m pip install --upgrade pip
26+
- name: Install pytorch
27+
run: python -m pip install torch==2.9.0.dev20250826 --extra-index-url https://download.pytorch.org/whl/nightly/cpu
28+
- name: Install monarch
29+
run: python -m pip install monarch-no-torch==0.1.0.dev20250826 --find-links assets/wheels
2630
- name: Install dependencies
27-
run: |
28-
python -m pip install --no-build-isolation -e ".[dev,cpu]" --extra-index-url https://download.pytorch.org/whl/nightly/cpu --find-links assets/wheels
31+
run: python -m pip install --no-build-isolation -e ".[dev]"
2932
- name: Run unit tests with coverage
3033
# TODO add all tests
3134
run: pytest tests/unit_tests --cov=. --cov-report=xml --durations=20 -vv

README.md

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,14 @@ source .venv/bin/activate
3535

3636
```bash
3737
# feature install if you don't have /user/local/cuda-12.8
38-
feature install --persist cuda_12_8
38+
feature install --persist cuda_12_9
3939

4040
# add env variables
41-
export CUDA_VERSION=12.8
42-
export NVCC=/usr/local/cuda-${CUDA_VERSION}/bin/nvcc
43-
export CUDA_NVCC_EXECUTABLE=/usr/local/cuda-${CUDA_VERSION}/bin/nvcc
44-
export CUDA_HOME=/usr/local/cuda-${CUDA_VERSION}
45-
export PATH="${CUDA_HOME}/bin:$PATH"
41+
export CUDA_VERSION=12.9
42+
export NVCC=/usr/local/cuda-$CUDA_VERSION/bin/nvcc
43+
export CUDA_NVCC_EXECUTABLE=/usr/local/cuda-$CUDA_VERSION/bin/nvcc
44+
export CUDA_HOME=/usr/local/cuda-$CUDA_VERSION
45+
export PATH="$CUDA_HOME/bin:$PATH"
4646
export CUDA_INCLUDE_DIRS=$CUDA_HOME/include
4747
export CUDA_CUDART_LIBRARY=$CUDA_HOME/lib64/libcudart.so
4848
export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH

apps/grpo/main.py

Lines changed: 19 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,11 @@
1212
import torch
1313
from datasets import load_dataset
1414
from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
15+
from forge.actors.replay_buffer import ReplayBuffer
1516
from forge.controller import ServiceConfig, spawn_service
1617
from forge.controller.actor import ForgeActor
18+
from forge.data.rewards import MathReward, ThinkingReward
19+
from forge.util.metric_logging import get_metric_logger
1720
from monarch.actor import endpoint
1821
from transformers import AutoModelForCausalLM, AutoTokenizer
1922

@@ -209,66 +212,18 @@ async def update_weights(self, policy_actor):
209212
self.logger.info(f"Updating weights took {end_time - start_time:.2f} seconds")
210213

211214

212-
def math_scoring_function(prompt: str, response: str, target: str) -> float:
213-
"""Function to score math correctness."""
214-
import re
215-
216-
# Extract expected answer from target
217-
expected_answer = (
218-
float(target.strip())
219-
if target.strip().replace(".", "").replace("-", "").isdigit()
220-
else None
221-
)
222-
223-
# Extract model answer from response
224-
patterns = [
225-
r"####\s*([+-]?\d+(?:\.\d+)?)", # GSM8K style answer format
226-
r"(?:the\s+)?answer\s+is\s*([+-]?\d+(?:\.\d+)?)",
227-
r"(?:answer:|result:)\s*([+-]?\d+(?:\.\d+)?)",
228-
r"=\s*([+-]?\d+(?:\.\d+)?)\s*(?:\.|$)", # equals near end
229-
r"\b([+-]?\d+(?:\.\d+)?)\s*(?:\.|$)", # number at end
230-
r"([+-]?\d+(?:\.\d+)?)", # any number (fallback)
231-
]
232-
233-
model_answer = None
234-
response_lower = response.lower().strip()
235-
for pattern in patterns:
236-
matches = re.findall(pattern, response_lower)
237-
if matches:
238-
model_answer = float(matches[-1])
239-
break
240-
241-
if expected_answer is None or model_answer is None:
242-
return 0.1 # Partial credit for attempting
243-
244-
# Check if answers match (with some tolerance for floating point)
245-
if abs(expected_answer - model_answer) < 1e-6:
246-
return 1.0 # Correct answer
247-
else:
248-
return 0.0 # Incorrect answer
249-
250-
251-
def thinking_scoring_function(prompt: str, response: str, target: str) -> float:
252-
"""Function to score thinking tag usage."""
253-
# Check if response contains <think></think> tags
254-
if "<think>" in response.lower() and "</think>" in response.lower():
255-
return 0.5
256-
else:
257-
return 0.0
258-
259-
260215
class RewardActor(ForgeActor):
261216
"""Reward actor that uses a list of scoring functions."""
262217

263-
def __init__(self, scoring_functions: list[Callable]):
218+
def __init__(self, reward_functions: list[Callable]):
264219
super().__init__()
265-
self.scoring_functions = scoring_functions
220+
self.reward_functions = reward_functions
266221

267222
@endpoint
268223
async def evaluate_response(self, prompt: str, response: str, target: str) -> float:
269224
total_reward = 0.0
270-
for scoring_fn in self.scoring_functions:
271-
reward = scoring_fn(prompt, response, target)
225+
for reward_fn in self.reward_functions:
226+
reward = reward_fn(prompt, response, target)
272227
total_reward += reward
273228
return total_reward
274229

@@ -388,6 +343,13 @@ async def main():
388343
group_size = 1
389344
model = "Qwen/Qwen3-1.7B"
390345

346+
# ---- Setup WandB Logger ---- #
347+
logger = get_metric_logger(
348+
"wandb",
349+
freq=1,
350+
project="grpo-training",
351+
)
352+
391353
# ---- Setup services ---- #
392354
default_service_cfg = ServiceConfig(
393355
procs_per_replica=1,
@@ -447,7 +409,7 @@ async def main():
447409
reward_actor = await spawn_service(
448410
default_service_cfg,
449411
RewardActor,
450-
scoring_functions=[math_scoring_function, thinking_scoring_function],
412+
reward_functions=[MathReward(), ThinkingReward()],
451413
)
452414

453415
print("All services initialized successfully!")
@@ -498,6 +460,7 @@ async def continuous_rollouts():
498460
print(
499461
f"Generated {rollout_count} rollouts w/ average reward {avg_reward}"
500462
)
463+
logger.log("reward/rollout", avg_reward, rollout_count)
501464

502465
async def continuous_training():
503466
training_step = 0
@@ -511,7 +474,9 @@ async def continuous_training():
511474
if training_step % 10 == 0:
512475
print(f"Completed {training_step} training steps")
513476
if training_result:
514-
print(f"Latest loss: {training_result.get('loss', 'N/A')}")
477+
loss_value = training_result.get("loss", 0.0)
478+
print(f"Latest loss: {loss_value}")
479+
logger.log("loss/training_step", loss_value, training_step)
515480
# await trainer.update_weights(policy)
516481

517482
print("Starting GRPO training loops...")

pyproject.toml

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,6 @@ oss = [
4747
"torchmonarch-nightly==2025.8.1",
4848
"torchstore",
4949
]
50-
cpu = [
51-
"torch==2.9.0.dev20250826",
52-
"monarch-no-torch==0.1.0.dev20250826",
53-
]
54-
5550

5651
[project.scripts]
5752
forge = "forge.cli.forge:main"
@@ -78,8 +73,8 @@ members = [
7873
# pytorch
7974
# TODO: get auto backend to work
8075
[[tool.uv.index]]
81-
name = "pytorch-nightly-cu128"
82-
url = "https://download.pytorch.org/whl/nightly/cu128"
76+
name = "pytorch-nightly-cu129"
77+
url = "https://download.pytorch.org/whl/nightly/cu129"
8378
#explicit = true
8479

8580
# vllm
@@ -89,8 +84,8 @@ url = "https://download.pytorch.org/whl/nightly/cu128"
8984
# explicit = true
9085

9186
[tool.uv.sources]
92-
torchtitan = { index = "pytorch-nightly-cu128" }
93-
torch = { index = "pytorch-nightly-cu128" }
87+
torchtitan = { index = "pytorch-nightly-cu129" }
88+
torch = { index = "pytorch-nightly-cu129" }
9489
torchstore = { git = "ssh://[email protected]/meta-pytorch/torchstore.git" }
9590
#vllm = { index = "vllm-nightly" }
9691

src/forge/cli/config.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
17
import argparse
28
import functools
39
import sys

src/forge/data/rewards.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
17
import re
28
from typing import Optional
39

0 commit comments

Comments
 (0)