Commit 41bdd93

Debugging Cuda issue

1 parent 3b7ee6d · commit 41bdd93

Showing 6 changed files with 234 additions and 19 deletions.

apps/grpo/main.py

Lines changed: 14 additions & 3 deletions
@@ -13,13 +13,18 @@
 import torch
 from datasets import load_dataset
 from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
-from forge.actors.reference_actor import compute_sequence_logprobs, RefModel
+from forge.actors.reference_actor import (
+    compute_sequence_logprobs,
+    HuggingFaceRefModel,
+    TitanRefModel,
+)
 from forge.actors.replay_buffer import ReplayBuffer
 from forge.controller.actor import ForgeActor
 from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
 from forge.data.rewards import MathReward, ThinkingReward
 from forge.util.metric_logging import get_metric_logger
 from monarch.actor import endpoint
+from torchtitan.config.job_config import Model as TitanJobModelConfig
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 logger = logging.getLogger(__name__)
@@ -329,10 +334,16 @@ async def main():
             gamma=0.99,
             lambda_=0.95,
         ),
+        # spawn_service(
+        #     ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
+        #     RefModel,
+        #     model_name=model,
+        # ),
+        # GOAL: Swap this in and everything should just "work"
        spawn_service(
            ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
-            RefModel,
-            model_name=model,
+            TitanRefModel,
+            model=TitanJobModelConfig(name=model),
        ),
        spawn_service(
            ServiceConfig(procs_per_replica=1, num_replicas=1),
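
Both reference backends expose the same `forward` endpoint, which is what makes the swap above (and the stated GOAL) a drop-in change. A minimal sketch of how a caller could toggle between them; `spawn_ref_service` and its `backend` flag are hypothetical and not part of this commit:

# Hypothetical helper (not in this commit): both backends share a
# ServiceConfig and a `forward` endpoint, so backend choice reduces to a flag.
async def spawn_ref_service(backend: str, model: str):
    cfg = ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True)
    if backend == "hf":
        return await spawn_service(cfg, HuggingFaceRefModel, model_name=model)
    # TitanRefModel takes a torchtitan Model config rather than a bare name.
    return await spawn_service(
        cfg, TitanRefModel, model=TitanJobModelConfig(name=model)
    )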

apps/grpo/test.py

Lines changed: 101 additions & 0 deletions
@@ -0,0 +1,101 @@
+import asyncio
+
+from datasets import load_dataset
+
+from forge.actors.policy import Policy, PolicyConfig, SamplingOverrides, WorkerConfig
+from forge.actors.reference_actor import HuggingFaceRefModel, TitanRefModel
+
+from forge.controller.actor import ForgeActor
+from forge.controller.service import ServiceConfig, shutdown_service, spawn_service
+from monarch.actor import endpoint
+
+
+class DatasetActor(ForgeActor):
+    """Actor wrapper for HuggingFace dataset to provide async interface."""
+
+    def __init__(
+        self, path: str, config_name: str, split: str, streaming: bool, **kwargs
+    ):
+        super().__init__()
+
+        def gsm8k_to_messages(sample):
+            question = sample["question"]
+            full_answer: str = sample["answer"]
+            answer = full_answer.split("#### ")[1]
+            return {"question": question, "answer": answer}
+
+        ds = load_dataset(path, config_name, split=split, streaming=streaming)
+        ds = ds.map(gsm8k_to_messages)
+        ds = ds.shuffle()
+        self._iterator = iter(ds)
+
+    @endpoint
+    async def __next__(self) -> dict[str, str] | None:
+        return next(self._iterator)
+
+
+# Sandbox; will be removed
+async def main():
+    group_size = 1
+
+    # For Torchtitan
+    model = "Qwen/Qwen3-1.7B"
+
+    # Spawn Reference "Agents"
+    hf_model = await spawn_service(
+        ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
+        HuggingFaceRefModel,
+        model_name=model,
+    )
+    titan_model = await spawn_service(
+        ServiceConfig(procs_per_replica=1, num_replicas=1, with_gpus=True),
+        TitanRefModel,
+    )
+
+    # Spawn Policy for getting responses
+    policy = await spawn_service(
+        ServiceConfig(procs_per_replica=1, with_gpus=True, num_replicas=1),
+        Policy,
+        config=PolicyConfig(
+            worker_params=WorkerConfig(model=model),
+            sampling_params=SamplingOverrides(num_samples=group_size, max_tokens=16),
+        ),
+    )
+
+    # Load Dataset
+    dataloader = await spawn_service(
+        ServiceConfig(procs_per_replica=1, num_replicas=1),
+        DatasetActor,
+        path="openai/gsm8k",
+        config_name="main",
+        split="train",
+        streaming=True,
+    )
+    sample = await dataloader.__next__.choose()
+    prompt, target = sample["question"], sample["answer"]
+    print("Sample: ", sample)
+
+    # Generate output from policy, then pass to reference agents
+    actions = await policy.generate.choose(prompt)
+    for action in actions:
+        print("Generated Action tok_ids: ", action.token_ids)
+
+        print("HuggingFace Results")
+        hf_logprobs: float = await hf_model.forward.choose(action.token_ids)
+        print("HF logprob: ", hf_logprobs)
+
+        print("Titan Results")
+        titan_logprobs: float = await titan_model.forward.choose(action.token_ids)
+        print("Titan logprob: ", titan_logprobs)
+        # TODO: Update forward to convert probs (logits) to logprobs
+
+    await asyncio.gather(
+        shutdown_service(policy),
+        shutdown_service(dataloader),
+        shutdown_service(hf_model),
+        shutdown_service(titan_model),
+    )
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
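
Both `forward` endpoints above still return raw model output, and the trailing TODO flags the missing logits-to-logprobs conversion. A minimal sketch of that step, assuming logits of shape [batch, seq, vocab] where position t predicts token t + 1; `logits_to_seq_logprob` is a hypothetical name, not the repo's `compute_sequence_logprobs`:

import torch
import torch.nn.functional as F

def logits_to_seq_logprob(logits: torch.Tensor, token_ids: torch.Tensor) -> torch.Tensor:
    # Logits at position t score the token at position t + 1, so shift both.
    shifted_logits = logits[:, :-1, :]
    targets = token_ids[:, 1:]
    # Normalize over the vocabulary, then pick each realized token's logprob.
    logprobs = F.log_softmax(shifted_logits, dim=-1)
    token_logprobs = torch.gather(logprobs, 2, targets.unsqueeze(-1)).squeeze(-1)
    # Sum per-token logprobs into a single log-probability per sequence.
    return token_logprobs.sum(dim=-1)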

apps/rl/main.py

Lines changed: 1 addition & 7 deletions
@@ -26,7 +26,7 @@
 
 
 async def run(cfg: DictConfig):
-    trainer, buffer, reference = await asyncio.gather(
+    trainer, buffer = await asyncio.gather(
         spawn_actors(
             name="trainer",
             actor_cls=RLTrainer,
@@ -40,24 +40,18 @@ async def run(cfg: DictConfig):
             cfg=cfg.replay_buffer,
             processes=cfg.replay_buffer.pop("processes"),
         ),
-        spawn_actors(
-            name="reference_actor",
-            actor_cls=ReferenceActor,
-        ),
     )
     print("Actors spawned")
 
     # Initialize everything
     await asyncio.gather(
         buffer.setup.call(),
         trainer.setup.call(),
-        reference.setup.call(),
     )
     print("Setup done")
 
     print("shutting down...")
     await asyncio.gather(*[a.mesh.stop() for a in [trainer]])
-    await reference.cleanup.call()
 
 
 @parse

src/forge/actors/__init__.py

Lines changed: 5 additions & 1 deletion
@@ -4,7 +4,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-__all__ = ["Policy", "PolicyRouter", "RLTrainer", "ReplayBuffer"]
+__all__ = ["Policy", "PolicyRouter", "RLTrainer", "ReplayBuffer", "TitanRefModel"]
 
 
 def __getattr__(name):
@@ -24,5 +24,9 @@ def __getattr__(name):
         from .replay_buffer import ReplayBuffer
 
         return ReplayBuffer
+    elif name == "TitanRefModel":
+        from .reference_actor import TitanRefModel
+
+        return TitanRefModel
     else:
         raise AttributeError(f"module {__name__} has no attribute {name}")
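
The module-level `__getattr__` above is the PEP 562 lazy-import pattern: `reference_actor` (and its torchtitan dependencies) is only imported the first time `TitanRefModel` is actually touched. A small usage sketch:

import forge.actors  # reference_actor is not imported yet

# First attribute access falls through to forge.actors.__getattr__,
# which imports reference_actor and returns the class.
TitanRefModel = forge.actors.TitanRefModel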

src/forge/actors/reference_actor.py

Lines changed: 111 additions & 6 deletions
@@ -37,6 +37,13 @@
 
 @dataclass
 class ReferenceActor(ForgeActor):
+    """
+    Original idea (not updated); on second thought this might be overkill
+    if we can rely on the Service Replicas to handle the queue, since there's
+    no real pre/post processing or host management (maybe later for DP?). For
+    now, just directly spin up services of the reference models.
+    """
+
     model: Model = field(default_factory=Model)
     # parallelism: Parallelism = field(default_factory=Parallelism)
     # comm: Comm = field(default_factory=Comm)
@@ -95,13 +102,18 @@ async def setup(self):
         # Spawn the RefModel
         self.ref_model = await spawn_service(
             default_service_cfg,
-            RefModel,
+            HuggingFaceRefModel,
             model_name=self.model.name,
             device=self.device,
         )
 
         # Kick off background processing
-        asyncio.create_task(self.run_processing.call())
+        self.start_processing()
+
+    def start_processing(self):
+        """Start the replica's processing loop if not already running."""
+        if self._run_task is None or self._run_task.done():
+            self._run_task = asyncio.create_task(self.run())
 
     @endpoint
     async def forward(self, token_ids: list[int]) -> torch.Tensor:
@@ -112,8 +124,7 @@ async def forward(self, token_ids: list[int]) -> torch.Tensor:
         self.queue.append((token_ids, fut))
         return await fut
 
-    @endpoint
-    async def run_processing(self):
+    async def run(self):
         """
         Simple loop to pass things along to the ref model
         """
@@ -127,11 +138,105 @@
             fut.set_result(model_output)
 
     @endpoint
-    async def cleanup(self) -> None:
+    async def stop(self) -> None:
         self.running = False
 
 
-class RefModel(ForgeActor):
+@dataclass
+class TitanRefModel(ForgeActor):
+    """
+    Represents a reference actor leveraging a torchtitan model for execution
+    """
+
+    # Refer to titan JobConfig for enabling more ForgeEngine configuration
+    model: Model = field(default_factory=Model)
+    parallelism: Parallelism = field(default_factory=Parallelism)
+
+    # Populated in setup (commented out for now for engine_config parsing)
+    # engine: ForgeEngine | None = None
+
+    def __post_init__(self):
+        """Initializes config types and env variables."""
+        # Instantiate dict fields
+        for f in fields(self):
+            attr = getattr(self, f.name)
+            if isinstance(attr, Mapping):
+                setattr(self, f.name, f.type(**attr))
+            elif not isinstance(attr, f.type):
+                raise TypeError(
+                    f"{f.name} should be a {f.type} type or a dict like object"
+                )
+
+        # torchrun normally handles these env variables, but we need to set
+        # them ourselves in monarch for now.
+        self.rank = current_rank().rank
+        self.size = math.prod(current_size().values())
+
+        env = {
+            "RANK": str(self.rank),
+            "LOCAL_RANK": str(self.rank),
+            "LOCAL_WORLD_SIZE": str(self.size),
+            "GROUP_RANK": str(self.size),
+            "GROUP_WORLD_SIZE": str(self.size),
+            "ROLE_RANK": str(self.rank),
+            "ROLE_WORLD_SIZE": str(self.size),
+            "ROLE_NAME": "rank",
+            "WORLD_SIZE": str(self.size),
+            "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
+        }
+        os.environ.update(env)
+
+    @endpoint
+    async def setup(self):
+        engine_config = {f.name: getattr(self, f.name) for f in fields(self)}
+        self.engine = ForgeEngine(ForgeJobConfig(**engine_config))
+
+    @endpoint
+    async def forward(self, token_ids: list[int]) -> torch.Tensor:
+        """
+        Given token_ids, return the log probability of the sequence
+        (used as the reference_logprobs for KL divergence).
+        """
+        model_parts = self.engine.model_parts
+        parallel_dims = self.engine.parallel_dims
+
+        # Use provided token_ids directly
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        input_ids = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(device)
+
+        optional_context_parallel_ctx = (
+            dist_utils.create_context_parallel_ctx(
+                cp_mesh=parallel_dims.world_mesh["cp"],
+                cp_buffers=[input_ids] + [m.freqs_cis for m in model_parts],
+                cp_seq_dims=[1] + [0 for _ in model_parts],
+                cp_no_restore_buffers={input_ids},
+                cp_rotate_method=self.parallelism.context_parallel_rotate_method,
+            )
+            if parallel_dims.cp_enabled
+            else None
+        )
+
+        if parallel_dims.pp_enabled:
+            raise NotImplementedError("PP not implemented yet")
+        else:
+            # Non-PP forward
+            with self.engine.train_context(optional_context_parallel_ctx):
+                assert len(model_parts) == 1
+                with self.engine.maybe_enable_amp:
+                    pred = model_parts[0](input_ids)
+
+        # TODO: Update compute_sequence_logprobs to convert probs (logits) to logprobs
+        return pred
+
+
+# Maintained to keep GRPO app prior to migration
+class HuggingFaceRefModel(ForgeActor):
+    """
+    Represents a reference actor leveraging HuggingFace for execution
+    """
+
     def __init__(self, model_name, device: torch.device | None = None):
         super().__init__()
         self.model_name = model_name
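
ReferenceActor's `forward`/`run` pair above is a future-backed request queue: callers enqueue `(token_ids, fut)` and await the future, while a single background task drains the queue and resolves results. A standalone sketch of the same pattern with hypothetical names (`QueuedWorker`, with list doubling standing in for the model call):

import asyncio

class QueuedWorker:
    """Toy version of ReferenceActor's queue-plus-futures pattern."""

    def __init__(self):
        self.queue: list[tuple[list[int], asyncio.Future]] = []
        self.running = True
        self._run_task: asyncio.Task | None = None

    def start_processing(self):
        # Mirrors start_processing above: at most one consumer loop.
        if self._run_task is None or self._run_task.done():
            self._run_task = asyncio.create_task(self.run())

    async def forward(self, token_ids: list[int]) -> list[int]:
        # Enqueue the request with a future the caller awaits.
        fut = asyncio.get_running_loop().create_future()
        self.queue.append((token_ids, fut))
        return await fut  # resolved by run()

    async def run(self):
        # Single consumer: drain the queue and resolve futures in order.
        while self.running:
            if not self.queue:
                await asyncio.sleep(0.01)
                continue
            token_ids, fut = self.queue.pop(0)
            fut.set_result([t * 2 for t in token_ids])  # stand-in for the model

async def demo():
    worker = QueuedWorker()
    worker.start_processing()
    print(await worker.forward([1, 2, 3]))  # [2, 4, 6]
    worker.running = False
    await worker._run_task

asyncio.run(demo())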

src/forge/actors/trainer.py

Lines changed: 2 additions & 2 deletions
@@ -12,6 +12,8 @@
 from dataclasses import dataclass, field, fields
 
 import torch
+
+from forge.controller import ForgeActor
 from monarch.actor import current_rank, current_size, endpoint
 from torchtitan.config.job_config import (
     ActivationCheckpoint,
@@ -30,8 +32,6 @@
 from torchtitan.experiments.forge.engine import ForgeEngine
 from torchtitan.experiments.forge.job_config import ForgeJobConfig
 
-from forge.controller import ForgeActor
-
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.INFO)
 