Skip to content

Commit 2054d63

Browse files
author
Allen Wang
committed
Merge branch 'main' into replica
2 parents 7d6b247 + 5a3807e commit 2054d63

File tree

6 files changed

+632
-16
lines changed

6 files changed

+632
-16
lines changed

apps/sft/llama3_8b.yaml

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,19 @@ training:
3131
max_norm: 1.0
3232
steps: 1000
3333
compile: false
34-
dataset: "c4"
34+
35+
validation:
36+
local_batch_size: 1
37+
freq: -1 # Change to a positive number to enable validation
38+
steps: 200 # Max number of validation steps per run; validation is disabled unless this is positive.
39+
40+
dataset:
41+
path: yahma/alpaca-cleaned
42+
split: train[:95%]
43+
44+
dataset_val:
45+
path: yahma/alpaca-cleaned
46+
split: train[95%:]
3547

3648
parallelism:
3749
data_parallel_replicate_degree: 1

apps/sft/main.py

Lines changed: 75 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,11 @@
1818
from forge.data.datasets.packed import PackedDataset, TextPacker
1919
from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset
2020
from forge.data.tokenizer import HuggingFaceModelTokenizer
21+
from forge.data.utils import batch_to_device, CROSS_ENTROPY_IGNORE_IDX
2122

2223
from omegaconf import DictConfig, OmegaConf
2324
from torch import nn
25+
2426
from torchdata.stateful_dataloader import StatefulDataLoader
2527
from torchtitan.components.loss import LossFunction
2628
from torchtitan.components.lr_scheduler import LRSchedulersContainer
@@ -30,6 +32,7 @@
3032
from torchtitan.experiments.forge.job_config import ForgeJobConfig
3133
from tqdm import tqdm
3234

35+
3336
# stubs for now
3437
Checkpointer = Any
3538
Dataloader = Any
@@ -63,7 +66,16 @@ def __init__(self, job_config: ForgeJobConfig):
6366
self.metric_logger = None # TODO: fix this
6467

6568
def setup(self):
66-
self.train_dataloader = self.setup_data()
69+
self.train_dataloader = self.setup_data(
70+
self.job_config.dataset,
71+
batch_size=self.job_config.training.local_batch_size,
72+
)
73+
74+
self.val_dataloader = self.setup_data(
75+
self.job_config.dataset_val,
76+
batch_size=self.job_config.validation.local_batch_size,
77+
)
78+
6779
# self.train_dataloader = self.setup_data(
6880
# self.train_config.train_dataset_config,
6981
# self.train_config.train_dataloader_config,
@@ -79,7 +91,7 @@ def setup(self):
7991
# self.profiler = self.setup_profiler(self.train_config.profiler_config)
8092
# self.logger = self.setup_logger(self.train_config.logger_config)
8193

82-
def setup_data(self):
94+
def setup_data(self, dataset_config, batch_size):
8395
tokenizer = HuggingFaceModelTokenizer(
8496
tokenizer_json_path=os.path.join(
8597
self.job_config.model.hf_assets_path, "tokenizer.json"
@@ -95,8 +107,8 @@ def setup_data(self):
95107
dataset = sft_iterable_dataset(
96108
model_transform=tokenizer,
97109
message_transform=AlpacaToMessages(),
98-
path="yahma/alpaca-cleaned",
99-
split="train",
110+
path=dataset_config.path,
111+
split=dataset_config.split,
100112
)
101113
packer = TextPacker(padding_idx=0)
102114
dataset = PackedDataset(
@@ -106,7 +118,7 @@ def setup_data(self):
106118
)
107119
dataloader = StatefulDataLoader(
108120
dataset=dataset,
109-
batch_size=self.job_config.training.local_batch_size,
121+
batch_size=batch_size,
110122
collate_fn=partial(
111123
collate_packed, mask_fn=packer.create_block_mask, device=self.device
112124
),
@@ -119,7 +131,10 @@ def setup_data(self):
119131
return dataloader
120132

121133
def forward_backward(
122-
self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
134+
self,
135+
input_dict: dict[str, torch.Tensor],
136+
labels: torch.Tensor,
137+
do_backward: bool = True,
123138
) -> torch.Tensor:
124139
model_parts = self.model_parts
125140
parallel_dims = self.parallel_dims
@@ -145,14 +160,16 @@ def forward_backward(
145160
targets, losses = (
146161
(labels, []) if self.pp_has_last_stage else (None, None)
147162
)
163+
if do_backward:
164+
pp_schedule_fn = self.pp_schedule.step
165+
else:
166+
pp_schedule_fn = self.pp_schedule.eval
148167
if self.pp_has_first_stage:
149-
self.pp_schedule.step(
168+
pp_schedule_fn(
150169
inputs, target=targets, losses=losses, input_batch=inputs
151170
)
152171
else:
153-
self.pp_schedule.step(
154-
target=targets, losses=losses, input_batch=inputs
155-
)
172+
pp_schedule_fn(target=targets, losses=losses, input_batch=inputs)
156173

157174
# accumulate losses across pipeline microbatches
158175
# TODO: PP+FSDP unexpectedly puts the loss back to the CPU
@@ -170,7 +187,8 @@ def forward_backward(
170187
loss = self.loss_fn(pred, labels)
171188
# need to free to before bwd to avoid peaking memory
172189
del pred
173-
loss.backward()
190+
if do_backward:
191+
loss.backward()
174192

175193
return loss
176194

@@ -214,6 +232,52 @@ def train(self) -> None:
214232
last_step=self.current_step == self.num_training_steps,
215233
)
216234

235+
if (
236+
self.job_config.validation.freq > 0
237+
and self.job_config.validation.steps > 0
238+
and self.current_step % self.job_config.validation.freq == 0
239+
):
240+
self.validate(self.job_config.validation.steps)
241+
242+
def validate(self, max_steps: int) -> None:
    """Run evaluation over the validation dataloader and report average loss.

    Puts every model part in eval mode, iterates ``self.val_dataloader`` for
    at most ``max_steps`` batches with gradients disabled, and accumulates a
    token-weighted loss so batches with different numbers of unmasked tokens
    average correctly. Totals are all-reduced across ranks before the final
    average is printed. Training mode is restored even if evaluation raises.

    Args:
        max_steps: Maximum number of validation batches to evaluate.
    """
    for m in self.model_parts:
        m.eval()
    total_val_loss = torch.tensor(0.0, device=self.device)
    total_val_tokens = torch.tensor(0.0, device=self.device)
    try:
        with torch.no_grad():
            val_pbar = tqdm(self.val_dataloader, desc="Validation", leave=False)
            for batch_idx, batch in enumerate(val_pbar):
                if batch_idx >= max_steps:
                    break
                batch_to_device(batch, self.device)
                # Count only tokens that contribute to the loss so the
                # average is per-token, not per-batch.
                current_num_tokens = (
                    batch["labels"] != CROSS_ENTROPY_IGNORE_IDX
                ).sum()
                # Compute loss (forward only; do_backward=False skips .backward()).
                labels = batch.pop("labels")
                loss = self.forward_backward(batch, labels, do_backward=False)
                total_val_loss += loss * current_num_tokens
                total_val_tokens += current_num_tokens
                # Update progress bar description with current average loss.
                avg_loss_so_far = (
                    (total_val_loss / total_val_tokens).item()
                    if total_val_tokens > 0
                    else float("inf")
                )
                val_pbar.set_description(
                    f"Running validation Loss: {avg_loss_so_far:.4f}"
                )
        # Aggregate validation metrics across all ranks.
        torch.distributed.all_reduce(total_val_loss)
        torch.distributed.all_reduce(total_val_tokens)
        avg_val_loss = (
            (total_val_loss / total_val_tokens).item()
            if total_val_tokens > 0
            else float("inf")
        )
    finally:
        # Restore training mode even on failure, so a broken validation
        # pass cannot leave the model stuck in eval mode for training.
        for m in self.model_parts:
            m.train()
    print(f"\nValidation loss: {avg_val_loss}")
280+
217281
def cleanup(self) -> None:
218282
if self.checkpointer:
219283
self.checkpointer.close()

src/forge/actors/policy.py

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,13 @@
1515
import torch
1616

1717
from monarch.actor import Actor, current_rank, endpoint, proc_mesh
18+
from torchstore import MultiProcessStore
19+
20+
from torchstore._state_dict_utils import DELIM
1821

1922
from vllm.engine.arg_utils import EngineArgs
2023
from vllm.entrypoints.utils import _validate_truncation_size
2124
from vllm.executor.multiproc_worker_utils import set_multiprocessing_worker_envs
22-
from vllm.inputs import TextPrompt, TokensPrompt
2325
from vllm.lora.request import LoRARequest
2426
from vllm.outputs import CompletionOutput
2527
from vllm.sampling_params import GuidedDecodingParams, RequestOutputKind, SamplingParams
@@ -37,6 +39,8 @@
3739
from vllm.v1.structured_output import StructuredOutputManager
3840
from vllm.worker.worker_base import WorkerWrapperBase
3941

42+
from forge.data.sharding import VLLMSharding
43+
4044
logger = logging.getLogger(__name__)
4145

4246

@@ -194,6 +198,7 @@ class Policy(Actor):
194198
enforce_eager: bool = False
195199
vllm_args: EngineArgs = None
196200
resources: int = 1
201+
state_dict_key: str = "model_state_dict"
197202

198203
def __post_init__(self):
199204
"""Build vLLM Arguments
@@ -238,7 +243,8 @@ def __post_init__(self):
238243
assert self.vllm_args.parallel_config.world_size == self.resources
239244

240245
@endpoint
241-
async def setup(self):
246+
async def setup(self, store: MultiProcessStore = None):
247+
self.torchstore = store
242248
# TODO: remove ["gpus"] when monarch implements a flat rank
243249
self.rank = current_rank()["gpus"]
244250
self.worker = self.setup_worker()
@@ -247,10 +253,50 @@ async def setup(self):
247253
async def execute_model(self, schedule: SchedulerOutput):
248254
return self.worker.execute_model(schedule)
249255

256+
async def _load_tensor_parallel_state_dict(self, current_state_dict: dict):
    """
    Load a full (unsharded) state dict from torchstore into this worker's
    tensor-parallel model shards.

    For each parameter in ``current_state_dict``, the full tensor is fetched
    from torchstore under ``"{state_dict_key}{DELIM}{param_name}"`` and the
    slice owned by this rank is copied into the live parameter tensor.

    Args:
        current_state_dict: Mapping of parameter name -> local (possibly
            sharded) tensor, overwritten in place.
    """
    # Setting explicitly to llama3 for now as it is our only use case.
    sharding = VLLMSharding(self.tensor_parallel_size, self.rank)

    updated_count = 0
    for param_name, current_tensor in current_state_dict.items():
        # Load the full tensor from torchstore.
        # TODO: only get the part of the tensor that is needed
        stored_tensor = await self.torchstore.get(
            f"{self.state_dict_key}{DELIM}{param_name}"
        )
        sharding.load_from_source_to_target(
            param_name,
            stored_tensor,
            current_tensor,
        )
        updated_count += 1

    # Surface how many parameters were actually written (was previously
    # counted but never reported).
    logger.debug("Loaded %d parameters from torchstore", updated_count)
250281
@endpoint
251282
async def update(self):
    """Update model weights by reading the state dict from torchstore.

    Fetches the current model's parameters and overwrites them in place with
    the tensors stored under ``self.state_dict_key``.

    Raises:
        RuntimeError: If no torchstore was provided at setup time.
    """
    if self.torchstore is None:
        # RuntimeError (not bare Exception); the old message claimed we were
        # "skipping" the update even though we raise here.
        raise RuntimeError("No torchstore configured; cannot update model weights")

    logger.debug(
        "Starting model update from torchstore with key: %s", self.state_dict_key
    )

    model = self.worker.model_runner.model
    current_state_dict = model.state_dict()
    logger.debug("Current state dict has %d parameters", len(current_state_dict))

    await self._load_tensor_parallel_state_dict(current_state_dict)
    logger.debug("Successfully updated model weights from torchstore")
254300

255301
@endpoint
256302
async def setup_kv_cache(self):
@@ -286,6 +332,17 @@ async def setup_kv_cache(self):
286332
async def get_vllm_args(self):
287333
return self.vllm_args
288334

335+
@endpoint
336+
async def get_model_params(self):
    """Return a CPU copy of a debug subset of the model's parameters.

    Only parameters whose name contains ``"layers.0"`` (the first transformer
    layer) are returned to keep the payload small — this is intended for
    verifying weight sync, not for full model export.

    Returns:
        dict[str, torch.Tensor]: parameter name -> detached CPU tensor.
    """
    model = self.worker.model_runner.model
    # Detach before copying to CPU so the copy never enters the autograd graph.
    return {
        name: param.detach().cpu()
        for name, param in model.named_parameters()
        if "layers.0" in name
    }
345+
289346
def setup_worker(self):
290347
"""Build and Instantiate vLLM worker"""
291348
parallel_config = self.vllm_args.parallel_config

0 commit comments

Comments
 (0)