Commit b8319c5 (merge)

2 parents: e0e280e + 9ce97bf

18 files changed: +993 −104

.github/workflows/unit_test.yaml

Lines changed: 1 addition & 1 deletion

@@ -26,7 +26,7 @@ jobs:
       - name: Install pytorch
         run: python -m pip install torch==2.9.0.dev20250826 --extra-index-url https://download.pytorch.org/whl/nightly/cpu
       - name: Install monarch
-        run: python -m pip install monarch-no-torch==0.1.0.dev20250826 --find-links assets/wheels
+        run: python -m pip install monarch-no-torch==0.1.0.dev20250826 --find-links assets/ci
       - name: Install dependencies
         run: python -m pip install --no-build-isolation -e ".[dev]"
       - name: Run unit tests with coverage
.gitignore

Lines changed: 2 additions & 0 deletions

@@ -187,3 +187,5 @@ cover/

 # wandb
 wandb/
+
+assets/wheels/vllm*.whl

README.md

Lines changed: 28 additions & 0 deletions

@@ -6,6 +6,34 @@

 ## Installation

+### Basic
+
+Forge requires the latest PyTorch nightly with Monarch, vLLM, and torchtitan. For convenience,
+we have pre-packaged these dependencies as wheels in assets/wheels.
+
+To install Forge easily:
+
+```bash
+conda create -n forge python=3.10
+./scripts/install.sh
+```
+
+You can test with
+```
+python -m apps.grpo.main
+```
+
+If you need to re-build the wheels for whatever reason, you can do so with:
+```bash
+conda create -n forge python=3.10
+./scripts/build_wheels.sh
+```
+
+Since the vLLM wheel is too large for GitHub, we uploaded it as a release:
+```
+$ gh release create v0.0.0 assets/wheels/vllm-*.whl --title "Forge Wheels v0.0.0"
+```
+
 ### Basic (Broken)

 ```bash

apps/grpo/main.py

Lines changed: 13 additions & 78 deletions

@@ -13,49 +13,20 @@
 import torch
 from datasets import load_dataset
 from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
+from forge.actors.reference_actor import compute_sequence_logprobs, TitanRefModel
 from forge.actors.replay_buffer import ReplayBuffer
 from forge.controller.actor import ForgeActor
 from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
 from forge.data.rewards import MathReward, ThinkingReward
 from forge.util.metric_logging import get_metric_logger
 from monarch.actor import endpoint
+from torchtitan.config.job_config import Model as TitanJobModelConfig
 from transformers import AutoModelForCausalLM, AutoTokenizer

 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)


-def compute_sequence_logprobs(
-    model: torch.nn.Module,
-    input_ids: torch.Tensor,
-    attention_mask: torch.Tensor,
-    requires_grad: bool = True,
-) -> torch.Tensor:
-    context_manager = torch.enable_grad() if requires_grad else torch.no_grad()
-
-    with context_manager:
-        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
-        logits = outputs.logits
-
-        # Apply log softmax to get log probabilities
-        log_probs = torch.log_softmax(logits, dim=-1)
-
-        # Extract log probabilities for the actual tokens (excluding the first token for next-token prediction)
-        shifted_input_ids = input_ids[:, 1:]  # Remove first token
-        shifted_log_probs = log_probs[:, :-1, :]  # Remove last logit
-
-        # Gather log probabilities for actual tokens
-        token_log_probs = torch.gather(
-            shifted_log_probs, dim=-1, index=shifted_input_ids.unsqueeze(-1)
-        ).squeeze(-1)
-
-        # Sum log probabilities across sequence (masked by attention)
-        shifted_attention_mask = attention_mask[:, 1:]
-        sequence_log_probs = (token_log_probs * shifted_attention_mask).sum(dim=-1)
-
-    return sequence_log_probs
-
-
 @dataclass
 class Group:
     response: str  # The response text for tokenization
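
The helper removed above now lives in forge.actors.reference_actor (hence the new import). Its core step is the standard next-token alignment: the logits at position t score the token at position t + 1, so the targets are shifted left and the logits trimmed on the right before the gather. A minimal standalone sketch with toy shapes (assumed names; not Forge code):

```python
# Minimal sketch of the shift-and-gather alignment used by
# compute_sequence_logprobs; toy shapes, standalone (not Forge code).
import torch

batch, seq_len, vocab = 1, 5, 10
logits = torch.randn(batch, seq_len, vocab)            # stand-in model output
input_ids = torch.randint(0, vocab, (batch, seq_len))
attention_mask = torch.ones(batch, seq_len)

log_probs = torch.log_softmax(logits, dim=-1)

# Logits at position t predict token t + 1: drop the first token from the
# targets and the last position from the logits so the two line up.
targets = input_ids[:, 1:]                             # (batch, seq_len - 1)
preds = log_probs[:, :-1, :]                           # (batch, seq_len - 1, vocab)

token_log_probs = torch.gather(preds, -1, targets.unsqueeze(-1)).squeeze(-1)
sequence_log_probs = (token_log_probs * attention_mask[:, 1:]).sum(dim=-1)
print(sequence_log_probs.shape)                        # torch.Size([1])
```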
@@ -273,48 +244,6 @@ async def __call__(self, groups: list[Group]) -> list[float]:
         return advantages


-class RefModel(ForgeActor):
-    def __init__(self, model_name, device: torch.device | None = None):
-        super().__init__()
-        self.model_name = model_name
-
-        # Set device
-        if device is None:
-            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        else:
-            self.device = device
-
-        # Initialize model and tokenizer
-        self.model = AutoModelForCausalLM.from_pretrained(
-            model_name,
-            torch_dtype=torch.bfloat16,
-            trust_remote_code=True,
-        ).to(self.device)
-
-        # Set model to eval mode for reference computations
-        self.model.eval()
-
-        self.logger.info(f"Model initialized on {self.device}")
-
-    @endpoint
-    async def forward(self, token_ids: list[int]) -> torch.Tensor:
-        # Use provided token_ids directly
-        input_ids = (
-            torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(self.device)
-        )
-        # Create attention mask of all 1s since we have actual tokens (no padding)
-        attention_mask = torch.ones_like(input_ids).to(self.device)
-
-        # Compute log probabilities using shared utility function
-        sequence_log_probs = compute_sequence_logprobs(
-            self.model, input_ids, attention_mask, requires_grad=False
-        )
-
-        return (
-            sequence_log_probs.squeeze()
-        )  # Remove batch dimension for single response
-
-
 class DatasetActor(ForgeActor):
     """Actor wrapper for HuggingFace dataset to provide async interface."""

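RefModel's replacement is not shown in this diff, only its usage: it is spawned with model=titan_model and called as forward.choose(request=..., response=...) further down. Roughly this interface can be inferred — a sketch of assumptions, not the actual forge.actors.reference_actor implementation:

```python
# Interface sketch inferred from this diff's call sites only; the real
# TitanRefModel lives in forge.actors.reference_actor and is not shown here.
import torch
from forge.controller.actor import ForgeActor
from monarch.actor import endpoint
from torchtitan.config.job_config import Model as TitanJobModelConfig

class TitanRefModelSketch(ForgeActor):
    def __init__(self, model: TitanJobModelConfig):
        super().__init__()
        self.model_config = model  # e.g. name="qwen3", flavor="0.6B"

    @endpoint
    async def forward(self, request: list[int], response: list[int]) -> torch.Tensor:
        # Score the response tokens conditioned on the request tokens and
        # return reference log-probabilities for the response.
        raise NotImplementedError
```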
@@ -345,7 +274,8 @@ async def __next__(self) -> dict[str, str] | None:
 async def main():
     """Main GRPO training loop with rollout and training processes."""
     group_size = 1
-    model = "Qwen/Qwen3-1.7B"
+    model = "Qwen/Qwen3-0.6B"
+    titan_model = TitanJobModelConfig(name="qwen3", flavor="0.6B")

     # ---- Setup WandB Logger ---- #
     logger = get_metric_logger(
@@ -403,8 +333,8 @@ async def main():
         ),
         spawn_service(
             ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
-            RefModel,
-            model_name=model,
+            TitanRefModel,
+            model=titan_model,
         ),
         spawn_service(
             ServiceConfig(procs_per_replica=1, num_replicas=1),
@@ -431,9 +361,14 @@ async def continuous_rollouts():
                 target=target,
                 policy_version=version,
             )
-            actions = await policy.generate.choose(prompt)
+            responses = await policy.generate.choose(prompt)
+            actions = responses.outputs
             for action in actions:
-                ref_logprobs = await ref_model.forward.choose(action.token_ids)
+                request_tokens = responses.prompt_token_ids
+                response_tokens = action.token_ids
+                ref_logprobs = await ref_model.forward.choose(
+                    request=request_tokens, response=response_tokens
+                )
                 reward = await reward_actor.evaluate_response.choose(
                     prompt=prompt, response=action.text, target=target
                 )
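
The rollout change tracks what appears to be a vLLM-style result object: generate now returns one response whose outputs field holds an entry per sample, with the shared prompt tokens (prompt_token_ids) kept separate from each sample's token_ids. Passing request and response separately lets the reference model score only the generated tokens. A hedged sketch of that scoring, assuming an HF-style model rather than the actual TitanRefModel internals:

```python
# Hedged sketch: reference logprobs over response tokens only, given separate
# request/response token lists (HF-style model assumed; not TitanRefModel).
import torch

def response_logprobs(model, request: list[int], response: list[int]) -> torch.Tensor:
    input_ids = torch.tensor([request + response])     # (1, len(req) + len(resp))
    with torch.no_grad():
        logits = model(input_ids=input_ids).logits     # (1, L, vocab)
    log_probs = torch.log_softmax(logits, dim=-1)
    # Logits at position t predict token t + 1, so the response tokens are
    # scored by logit positions len(request) - 1 .. L - 2.
    start = len(request) - 1
    scores = log_probs[:, start : start + len(response), :]
    idx = torch.tensor([response]).unsqueeze(-1)       # (1, len(resp), 1)
    return torch.gather(scores, -1, idx).squeeze(-1)   # (1, len(resp))
```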

apps/rl/main.py

Lines changed: 0 additions & 1 deletion

@@ -20,7 +20,6 @@
 from forge.controller import spawn_actors
 from omegaconf import DictConfig

-
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)

apps/sft_v2/llama3_8b.yaml

Lines changed: 1 addition & 1 deletion

@@ -20,7 +20,7 @@ processes:
   scheduler: local  # local | mast (not supported yet)
   num_hosts: 1
   num_procs: 8
-  num_gpus: 8
+  with_gpus: true

 optimizer:
   name: AdamW

apps/vllm/main.py

Lines changed: 1 addition & 1 deletion

@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 """To run:
-
+export HF_HUB_DISABLE_XET=1
 python -m apps.vllm.main --guided-decoding --num-samples 3

 """
Two binary files changed (31.9 MB and 7.29 KB); contents not shown.
