Commit fd1d38b

Merge branch 'meta-pytorch:main' into main
2 parents 02d77c6 + 28aa995

File tree

8 files changed: +576 -35 lines

apps/grpo/main.py

Lines changed: 530 additions & 0 deletions
Large diffs are not rendered by default.

apps/rl/llama3_8b.yaml

Lines changed: 2 additions & 2 deletions
@@ -46,7 +46,7 @@ trainer:
   disable_loss_parallel: false

 checkpoint:
-  enable_checkpoint: true
+  enable: true
   folder: /tmp/Meta-Llama-3.1-8B-Instruct/saved_checkpoints
   initial_load_path: /tmp/Meta-Llama-3.1-8B-Instruct/
   initial_load_in_hf: true
@@ -119,7 +119,7 @@ replay_buffer:
 # disable_loss_parallel: false
 #
 # checkpoint:
-#   enable_checkpoint: true
+#   enable: true
 # folder: /tmp/Meta-Llama-3.1-8B-Instruct/
 # interval: 500
 # async_mode: "disabled"

apps/sft/llama3_8b.yaml

Lines changed: 1 addition & 1 deletion
@@ -55,7 +55,7 @@ parallelism:
   disable_loss_parallel: false

 checkpoint:
-  enable_checkpoint: true
+  enable: true
   folder: /tmp/Meta-Llama-3.1-8B-Instruct/saved_checkpoints
   initial_load_path: /tmp/Meta-Llama-3.1-8B-Instruct/
   initial_load_in_hf: true

apps/sft_v2/llama3_8b.yaml

Lines changed: 1 addition & 1 deletion
@@ -47,7 +47,7 @@ parallelism:
   disable_loss_parallel: false

 checkpoint:
-  enable_checkpoint: true
+  enable: true
   folder: /tmp/Meta-Llama-3.1-8B-Instruct/saved_checkpoints
   initial_load_path: /tmp/Meta-Llama-3.1-8B-Instruct/
   initial_load_in_hf: true
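
The three YAML diffs above (plus the commented-out block in apps/rl/llama3_8b.yaml) make the same breaking rename: the checkpoint section's `enable_checkpoint` key becomes `enable`. For configs maintained outside this repo, a migration could look like the sketch below. This helper is hypothetical, not part of this commit, and it assumes PyYAML is available:

import yaml  # assumes PyYAML; not a dependency introduced by this commit


def migrate_checkpoint_key(path: str) -> None:
    """Rename `checkpoint.enable_checkpoint` to `checkpoint.enable` in place."""
    with open(path) as f:
        cfg = yaml.safe_load(f)
    checkpoint = cfg.get("checkpoint", {})
    if "enable_checkpoint" in checkpoint:
        # Preserve the original boolean value under the new key name.
        checkpoint["enable"] = checkpoint.pop("enable_checkpoint")
    with open(path, "w") as f:
        yaml.safe_dump(cfg, f, sort_keys=False)


# e.g. migrate_checkpoint_key("apps/sft/llama3_8b.yaml")

Note that yaml.safe_dump drops comments, so hand-editing remains safer for heavily annotated configs.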

src/forge/actors/policy.py

Lines changed: 8 additions & 2 deletions
@@ -60,6 +60,7 @@ class SamplingOverrides:

     num_samples: int
     guided_decoding: bool = False
+    max_tokens: int = 512


 @dataclass
@@ -87,6 +88,7 @@ class PolicyConfig:
     num_workers: int
     worker_params: WorkerConfig
     sampling_params: SamplingOverrides
+    available_devices: str = None


 @dataclass
@@ -102,6 +104,11 @@ class Policy(PolicyInterface):
     @endpoint
     async def setup(self):
         # Set up policy_worker
+        self.available_devices = (
+            self.config.available_devices
+            if self.config.available_devices is not None
+            else ",".join(str(i) for i in range(torch.cuda.device_count()))
+        )
         await self.spawn_workers()

         self.request_id = 0
@@ -157,6 +164,7 @@ async def spawn_workers(self):
             env={
                 "MASTER_ADDR": str(get_loopback_ip()),
                 "MASTER_PORT": str(get_open_port()),
+                "CUDA_VISIBLE_DEVICES": self.available_devices,
             },
         )
         self.policy_worker = await self.worker_mesh.spawn(
@@ -200,7 +208,6 @@ async def generate(self, prompt: str, priority: int = 0) -> List[CompletionOutput]:
         if (num_samples := self.sampling_params.n) == 1:
             self.output_processor.add_request(request, prompt_str, None, 0)
             request, _ = self.preprocess_add_request(request)
-
             request_fut = asyncio.Future()
             self.requests[request_id] = (None, request_fut)

@@ -456,7 +463,6 @@ def convert_input(prompt=None, prompt_token_ids=None) -> Dict:

 def get_default_sampling_params(vllm_config, overrides=None) -> SamplingParams:
     default_params = vllm_config.model_config.get_diff_sampling_param()
-    default_params["max_tokens"] = 512
     if overrides is not None:
         default_params |= overrides
     if default_params:
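
Two behavioral changes stand out in this diff: the hard-coded `default_params["max_tokens"] = 512` is replaced by a configurable `SamplingOverrides.max_tokens` field with the same default, and the new `PolicyConfig.available_devices` lets callers pin workers to specific GPUs, falling back to every visible device and exported to each worker as `CUDA_VISIBLE_DEVICES`. A standalone sketch of that fallback logic (the function name is illustrative, not from the diff):

import torch


def resolve_available_devices(available_devices: str | None = None) -> str:
    """Mirror the fallback added in Policy.setup: use the configured string
    if present, otherwise enumerate every visible CUDA device."""
    if available_devices is not None:
        return available_devices
    return ",".join(str(i) for i in range(torch.cuda.device_count()))


print(resolve_available_devices())       # on a 4-GPU host: "0,1,2,3"
print(resolve_available_devices("0,2"))  # pin spawned workers to GPUs 0 and 2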

src/forge/actors/replay_buffer.py

Lines changed: 13 additions & 17 deletions
@@ -11,7 +11,6 @@
 from monarch.actor import endpoint

 from forge.controller import ForgeActor
-from forge.types import Trajectory


 @dataclass
@@ -24,50 +23,47 @@ class ReplayBuffer(ForgeActor):

     @endpoint
     async def setup(self) -> None:
-        self.buffer: list[Trajectory] = []
+        self.buffer: list = []
         if self.seed is None:
             self.seed = random.randint(0, 2**32)
         random.seed(self.seed)
         self.sampler = random.sample

     @endpoint
-    async def add(self, trajectory: Trajectory) -> None:
-        self.buffer.append(trajectory)
+    async def add(self, episode) -> None:
+        self.buffer.append(episode)

     @endpoint
-    async def sample(
-        self, curr_policy_version: int, batch_size: int | None = None
-    ) -> list[Trajectory] | None:
+    async def sample(self, curr_policy_version: int, batch_size: int | None = None):
         """Sample from the replay buffer.

         Args:
             curr_policy_version (int): The current policy version.
-            batch_size (int, optional): Number of trajectories to sample. If none, defaults to batch size
+            batch_size (int, optional): Number of episodes to sample. If none, defaults to batch size
                 passed in at initialization.

         Returns:
-            A list of sampled trajectories or None if there are not enough trajectories in the buffer.
+            A list of sampled episodes or None if there are not enough episodes in the buffer.
         """
         bsz = batch_size if batch_size is not None else self.batch_size

-        # Evict old trajectories
+        # Evict old episodes
         self._evict(curr_policy_version)

         if bsz > len(self.buffer):
-            print("Not enough trajectories in the buffer.")
             return None

         # TODO: Make this more efficient
         idx_to_sample = self.sampler(range(len(self.buffer)), k=bsz)
         sorted_idxs = sorted(
             idx_to_sample, reverse=True
         )  # Sort in desc order to avoid shifting idxs
-        sampled_trajectories = [self.buffer.pop(i) for i in sorted_idxs]
-        return sampled_trajectories
+        sampled_episodes = [self.buffer.pop(i) for i in sorted_idxs]
+        return sampled_episodes

     @endpoint
     async def evict(self, curr_policy_version: int) -> None:
-        """Evict trajectories from the replay buffer if they are too old based on the current policy version
+        """Evict episodes from the replay buffer if they are too old based on the current policy version
         and the max policy age allowed.

         Args:
@@ -83,17 +79,17 @@ def _evict(self, curr_policy_version: int) -> None:
         ]

     @endpoint
-    async def _getitem(self, idx: int) -> Trajectory:
+    async def _getitem(self, idx: int):
         return self.buffer[idx]

     @endpoint
     async def _numel(self) -> int:
-        """Number of elements (trajectories) in the replay buffer."""
+        """Number of elements (episodes) in the replay buffer."""
         return len(self.buffer)

     @endpoint
     async def clear(self) -> None:
-        """Clear the replay buffer immediately - dropping all trajectories."""
+        """Clear the replay buffer immediately - dropping all episodes."""
         self.buffer.clear()

     @endpoint
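
Beyond the trajectory-to-episode rename (and `sample` now returning None silently instead of printing when the buffer is too small), the pattern `sample` relies on is worth spelling out: popping sampled indices in descending order keeps the remaining indices valid, since popping a low index first would shift every index behind it. A self-contained sketch of the same pattern (generic names, not the actor code itself):

import random


def sample_and_remove(buffer: list, k: int, rng: random.Random) -> list | None:
    """Sample k items without replacement, removing them from the buffer."""
    if k > len(buffer):
        return None
    idxs = rng.sample(range(len(buffer)), k=k)
    # Pop from the highest index down so earlier pops don't shift later ones.
    return [buffer.pop(i) for i in sorted(idxs, reverse=True)]


buf = ["ep0", "ep1", "ep2", "ep3", "ep4"]
picked = sample_and_remove(buf, 2, random.Random(0))
print(picked, buf)  # two episodes removed; three remain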

src/forge/controller/replica.py

Lines changed: 2 additions & 2 deletions
@@ -13,11 +13,11 @@
 from enum import Enum
 from typing import Optional

-from monarch.actor import Actor, ActorError, ProcMesh
-
 from forge.controller import get_proc_mesh
 from forge.types import ProcessConfig

+from monarch.actor import Actor, ActorError, ProcMesh
+
 logger = logging.getLogger(__name__)
 logger.setLevel(logging.DEBUG)

src/forge/util/metric_logging.py

Lines changed: 19 additions & 10 deletions
@@ -6,7 +6,7 @@
 import os
 import sys
 import time
-from typing import Mapping, Optional
+from typing import Mapping, Optional, Union

 from forge.interfaces import MetricLogger
 from forge.types import Scalar
@@ -21,11 +21,12 @@ class StdoutLogger(MetricLogger):
     """Logger to standard output.

     Args:
-        freq (Mapping[str, int]):
-            calls to `log` and `log_dict` will be ignored if `step % freq[metric_name] != 0`
+        freq (Union[int, Mapping[str, int]]):
+            If int, all metrics will be logged at this frequency.
+            If Mapping, calls to `log` and `log_dict` will be ignored if `step % freq[metric_name] != 0`
     """

-    def __init__(self, freq: Mapping[str, int]):
+    def __init__(self, freq: Union[int, Mapping[str, int]]):
         self._freq = freq

     def is_log_step(self, name: str, step: int) -> bool:
@@ -35,6 +36,8 @@ def is_log_step(self, name: str, step: int) -> bool:
             name (str): metric name (for checking the freq for this metric)
             step (int): current step
         """
+        if isinstance(self._freq, int):
+            return step % self._freq == 0
         return step % self._freq[name] == 0

     def log(self, name: str, data: Scalar, step: int) -> None:
@@ -77,8 +80,9 @@ class TensorBoardLogger(MetricLogger):
     """Logger for use w/ PyTorch's implementation of TensorBoard (https://pytorch.org/docs/stable/tensorboard.html).

     Args:
-        freq (Mapping[str, int]):
-            calls to `log` and `log_dict` will be ignored if `step % freq[metric_name] != 0`
+        freq (Union[int, Mapping[str, int]]):
+            If int, all metrics will be logged at this frequency.
+            If Mapping, calls to `log` and `log_dict` will be ignored if `step % freq[metric_name] != 0`
         log_dir (str): torch.TensorBoard log directory
         organize_logs (bool): If `True`, this class will create a subdirectory within `log_dir` for the current
             run. Having sub-directories allows you to compare logs across runs. When TensorBoard is
@@ -103,7 +107,7 @@ class TensorBoardLogger(MetricLogger):

     def __init__(
         self,
-        freq: Mapping[str, int],
+        freq: Union[int, Mapping[str, int]],
         log_dir: str = "metrics_log",
         organize_logs: bool = True,
         **kwargs,
@@ -133,6 +137,8 @@ def is_log_step(self, name: str, step: int) -> bool:
             name (str): metric name (for checking the freq for this metric)
             step (int): current step
         """
+        if isinstance(self._freq, int):
+            return step % self._freq == 0
         return step % self._freq[name] == 0

     def log(self, name: str, data: Scalar, step: int) -> None:
@@ -168,8 +174,9 @@ class WandBLogger(MetricLogger):
     For more information about arguments expected by WandB, see https://docs.wandb.ai/ref/python/init.

     Args:
-        freq (Mapping[str, int]):
-            calls to `log` and `log_dict` will be ignored if `step % freq[metric_name] != 0`
+        freq (Union[int, Mapping[str, int]]):
+            If int, all metrics will be logged at this frequency.
+            If Mapping, calls to `log` and `log_dict` will be ignored if `step % freq[metric_name] != 0`
         log_dir (Optional[str]): WandB log directory.
         project (str): WandB project name. Default is `torchtune`.
         entity (Optional[str]): WandB entity name. If you don't specify an entity,
@@ -197,7 +204,7 @@ class WandBLogger(MetricLogger):

     def __init__(
         self,
-        freq: Mapping[str, int],
+        freq: Union[int, Mapping[str, int]],
         project: str,
         log_dir: str = "metrics_log",
         entity: Optional[str] = None,
@@ -241,6 +248,8 @@ def is_log_step(self, name: str, step: int) -> bool:
             name (str): metric name (for checking the freq for this metric)
             step (int): current step
         """
+        if isinstance(self._freq, int):
+            return step % self._freq == 0
         return step % self._freq[name] == 0

     def log(self, name: str, data: Scalar, step: int) -> None:
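
The net effect across StdoutLogger, TensorBoardLogger, and WandBLogger: `freq` may now be a single int (one frequency for every metric) or, as before, a per-metric mapping. The shared check is small enough to show standalone (a minimal sketch of the same logic, not the module verbatim):

from typing import Mapping, Union


def is_log_step(freq: Union[int, Mapping[str, int]], name: str, step: int) -> bool:
    """Return True when metric `name` should be logged at `step`."""
    if isinstance(freq, int):
        return step % freq == 0    # one global frequency for every metric
    return step % freq[name] == 0  # per-metric frequency lookup


assert is_log_step(10, "loss", 20)               # int: log every 10 steps
assert not is_log_step({"loss": 7}, "loss", 20)  # mapping: 20 % 7 != 0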
