
Commit b38c936 (merge)

Parents: 903204a + ce1ed98

20 files changed: +2440 additions, -1397 deletions

20 files changed

+2440
-1397
lines changed

apps/sft/llama3_8b.yaml

Lines changed: 13 additions & 1 deletion
@@ -31,7 +31,19 @@ training:
   max_norm: 1.0
   steps: 1000
   compile: false
-  dataset: "c4"
+
+validation:
+  local_batch_size: 1
+  freq: -1  # Change to a positive number to enable validation
+  steps: 200  # Max steps to run validation. Validation disabled if negative.
+
+dataset:
+  path: yahma/alpaca-cleaned
+  split: train[:95%]
+
+dataset_val:
+  path: yahma/alpaca-cleaned
+  split: train[95%:]
 
 parallelism:
   data_parallel_replicate_degree: 1
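
Note (not part of the diff): the new dataset / dataset_val entries use the Hugging Face datasets split-slicing syntax, so training and validation read disjoint 95% / 5% slices of the same Alpaca split. A minimal sketch of what those two split strings select, bypassing the forge sft_iterable_dataset pipeline entirely and shown only for illustration:

    # Illustrative only: resolve the two split strings from the YAML directly
    # with the Hugging Face `datasets` library to show they are disjoint slices.
    from datasets import load_dataset

    train_ds = load_dataset("yahma/alpaca-cleaned", split="train[:95%]")  # first 95%
    val_ds = load_dataset("yahma/alpaca-cleaned", split="train[95%:]")    # remaining 5%

    print(len(train_ds), len(val_ds))

Validation stays disabled by default (freq: -1); setting freq to a positive step interval, with steps also positive, turns it on.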

apps/sft/main.py

Lines changed: 75 additions & 11 deletions
@@ -18,10 +18,12 @@
 from forge.data.datasets.packed import PackedDataset, TextPacker
 from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset
 from forge.data.tokenizer import HuggingFaceModelTokenizer
+from forge.data.utils import batch_to_device, CROSS_ENTROPY_IGNORE_IDX
 from forge.util import get_metric_logger
 
 from omegaconf import DictConfig, OmegaConf
 from torch import nn
+
 from torchdata.stateful_dataloader import StatefulDataLoader
 from torchtitan.components.loss import LossFunction
 from torchtitan.components.lr_scheduler import LRSchedulersContainer
@@ -31,6 +33,7 @@
 from torchtitan.experiments.forge.job_config import ForgeJobConfig
 from tqdm import tqdm
 
+
 # stubs for now
 Checkpointer = Any
 Dataloader = Any
@@ -64,7 +67,16 @@ def __init__(self, job_config: ForgeJobConfig):
         self.metric_logger = get_metric_logger(**job_config.metrics)
 
     def setup(self):
-        self.train_dataloader = self.setup_data()
+        self.train_dataloader = self.setup_data(
+            self.job_config.dataset,
+            batch_size=self.job_config.training.local_batch_size,
+        )
+
+        self.val_dataloader = self.setup_data(
+            self.job_config.dataset_val,
+            batch_size=self.job_config.validation.local_batch_size,
+        )
+
         # self.train_dataloader = self.setup_data(
         #     self.train_config.train_dataset_config,
         #     self.train_config.train_dataloader_config,
@@ -80,7 +92,7 @@ def setup(self):
         # self.profiler = self.setup_profiler(self.train_config.profiler_config)
         # self.logger = self.setup_logger(self.train_config.logger_config)
 
-    def setup_data(self):
+    def setup_data(self, dataset_config, batch_size):
         tokenizer = HuggingFaceModelTokenizer(
             tokenizer_json_path=os.path.join(
                 self.job_config.model.hf_assets_path, "tokenizer.json"
@@ -96,8 +108,8 @@ def setup_data(self):
         dataset = sft_iterable_dataset(
             model_transform=tokenizer,
             message_transform=AlpacaToMessages(),
-            path="yahma/alpaca-cleaned",
-            split="train",
+            path=dataset_config.path,
+            split=dataset_config.split,
         )
         packer = TextPacker(padding_idx=0)
         dataset = PackedDataset(
@@ -107,7 +119,7 @@
         )
         dataloader = StatefulDataLoader(
             dataset=dataset,
-            batch_size=self.job_config.training.local_batch_size,
+            batch_size=batch_size,
             collate_fn=partial(
                 collate_packed, mask_fn=packer.create_block_mask, device=self.device
             ),
@@ -120,7 +132,10 @@ def setup_data(self):
         return dataloader
 
     def forward_backward(
-        self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
+        self,
+        input_dict: dict[str, torch.Tensor],
+        labels: torch.Tensor,
+        do_backward: bool = True,
     ) -> torch.Tensor:
         model_parts = self.model_parts
         parallel_dims = self.parallel_dims
@@ -146,14 +161,16 @@ def forward_backward(
                 targets, losses = (
                     (labels, []) if self.pp_has_last_stage else (None, None)
                 )
+                if do_backward:
+                    pp_schedule_fn = self.pp_schedule.step
+                else:
+                    pp_schedule_fn = self.pp_schedule.eval
                 if self.pp_has_first_stage:
-                    self.pp_schedule.step(
+                    pp_schedule_fn(
                         inputs, target=targets, losses=losses, input_batch=inputs
                     )
                 else:
-                    self.pp_schedule.step(
-                        target=targets, losses=losses, input_batch=inputs
-                    )
+                    pp_schedule_fn(target=targets, losses=losses, input_batch=inputs)
 
             # accumulate losses across pipeline microbatches
             # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
@@ -171,7 +188,8 @@ def forward_backward(
                     loss = self.loss_fn(pred, labels)
                 # need to free to before bwd to avoid peaking memory
                 del pred
-                loss.backward()
+                if do_backward:
+                    loss.backward()
 
         return loss
 
@@ -216,6 +234,52 @@ def train(self) -> None:
                 last_step=self.current_step == self.num_training_steps,
             )
 
+            if (
+                self.job_config.validation.freq > 0
+                and self.job_config.validation.steps > 0
+                and self.current_step % self.job_config.validation.freq == 0
+            ):
+                self.validate(self.job_config.validation.steps)
+
+    def validate(self, max_steps: int) -> None:
+        for m in self.model_parts:
+            m.eval()
+        total_val_loss = torch.tensor(0.0, device=self.device)
+        total_val_tokens = torch.tensor(0.0, device=self.device)
+        with torch.no_grad():
+            val_pbar = tqdm(self.val_dataloader, desc="Validation", leave=False)
+            for batch_idx, batch in enumerate(val_pbar):
+                if batch_idx >= max_steps:
+                    break
+                batch_to_device(batch, self.device)
+                current_num_tokens = (batch["labels"] != CROSS_ENTROPY_IGNORE_IDX).sum()
+                # Compute loss
+                labels = batch.pop("labels")
+                loss = self.forward_backward(batch, labels, do_backward=False)
+                val_loss = loss * current_num_tokens
+                total_val_loss += val_loss
+                total_val_tokens += current_num_tokens
+                # Update progress bar description with current average loss
+                avg_loss_so_far = (
+                    (total_val_loss / total_val_tokens).item()
+                    if total_val_tokens > 0
+                    else float("inf")
+                )
+                val_pbar.set_description(
+                    f"Running validation Loss: {avg_loss_so_far:.4f}"
+                )
+        # Aggregate validation metrics across all ranks
+        torch.distributed.all_reduce(total_val_loss)
+        torch.distributed.all_reduce(total_val_tokens)
+        avg_val_loss = (
+            (total_val_loss / total_val_tokens).item()
+            if total_val_tokens > 0
+            else float("inf")
+        )
+        for m in self.model_parts:
+            m.train()
+        print(f"\nValidation loss: {avg_val_loss}")
+
     def cleanup(self) -> None:
         if self.checkpointer:
             self.checkpointer.close()
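
Note (not part of the diff): the validation loss in validate() is a token-weighted average: each batch contributes its mean loss multiplied by the number of non-ignored label tokens, and the final value divides the two accumulated sums (all-reduced across ranks) so batches with more target tokens weigh proportionally more. A minimal single-process sketch of that accumulation; the ignore-index value (-100) and the example tensors are assumptions made for illustration:

    # Illustrative sketch of the token-weighted averaging used in validate().
    import torch

    CROSS_ENTROPY_IGNORE_IDX = -100  # assumed ignore index for padded/ignored labels

    batches = [
        # (per-token mean loss from the forward pass, labels for that batch)
        (torch.tensor(1.20), torch.tensor([5, 7, -100, -100])),  # 2 real target tokens
        (torch.tensor(0.80), torch.tensor([3, 4, 9, 2])),        # 4 real target tokens
    ]

    total_loss = torch.tensor(0.0)
    total_tokens = torch.tensor(0.0)
    for loss, labels in batches:
        num_tokens = (labels != CROSS_ENTROPY_IGNORE_IDX).sum()
        total_loss += loss * num_tokens  # un-average the batch loss back to a sum
        total_tokens += num_tokens

    # (1.2 * 2 + 0.8 * 4) / 6 = 0.9333...
    print(f"Validation loss: {(total_loss / total_tokens).item():.4f}")

    # With multiple ranks, validate() all-reduces both accumulators before dividing:
    #   torch.distributed.all_reduce(total_loss)
    #   torch.distributed.all_reduce(total_tokens)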
