Skip to content

Commit a0f62e7

Browse files
author
Hossein Kavianihamedani
committed
Add an evaluation loop to the SFT recipe
1 parent 7550664 commit a0f62e7

File tree

8 files changed

+139
-2165
lines changed

8 files changed

+139
-2165
lines changed

apps/sft/llama3_8b.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ training:
3333
steps: 1000
3434
compile: false
3535
dataset: "c4"
36+
#eval_interval: 500 # Run evaluation every eval_interval training steps
37+
#eval_steps: 100 # Number of validation batches during each evaluation run
3638

3739
parallelism:
3840
data_parallel_replicate_degree: 1

apps/sft/main.py

Lines changed: 137 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
"""To run:
88
99
python -m apps.sft.main --config apps/sft/llama3_8b.yaml
10-
1110
"""
1211

1312
import asyncio
@@ -40,8 +39,6 @@
4039
from torchtitan.experiments.forge.engine import ForgeEngine
4140
from torchtitan.experiments.forge.job_config import ForgeJobConfig
4241

43-
# from tqdm import tqdm
44-
4542
# stubs for now
4643
Checkpointer = Any
4744
Dataloader = Any
@@ -64,7 +61,7 @@ class ForgeSFTRecipe(ForgeActor, ForgeEngine):
6461
checkpointer: Checkpointer
6562
tokenizer: Tokenizer
6663
train_dataloader: Dataloader
67-
# val_dataloader: Dataloader
64+
val_dataloader: Dataloader
6865
metric_logger: MetricLogger
6966
profiler: Profiler
7067
device: torch.device
@@ -81,6 +78,11 @@ def __init__(self, config: DictConfig):
8178
self.gradient_accumulation_steps = 1 # Example value, adjust as needed
8279
self._rank = current_rank().rank
8380
self._size = math.prod(current_size().values())
81+
82+
# Evaluation settings
83+
self.eval_interval = job_config.training.get("eval_interval", float("inf"))
84+
self.eval_steps = job_config.training.get("eval_steps", 0)
85+
8486
self._init_dist()
8587
super().__init__(job_config)
8688

@@ -111,25 +113,23 @@ def _init_dist(self):
111113

112114
@endpoint
async def setup(self):
    """Build the train/val dataloaders and restore checkpoint state.

    The Alpaca train split is partitioned into two disjoint slices:
    the first 90% feeds training, the held-out last 10% feeds evaluation.
    """
    dataset_path = "yahma/alpaca-cleaned"

    # Training data: first 90% of the train split.
    self.train_dataloader = self.setup_data(
        dataset_path=dataset_path, dataset_split="train[:90%]"
    )

    # Validation data: remaining 10%, never seen during training.
    self.val_dataloader = self.setup_data(
        dataset_path=dataset_path, dataset_split="train[90%:]"
    )

    # Resume from the last saved step if a checkpoint exists.
    self.checkpointer.load(step=self.current_step)

132-
def setup_data(self):
129+
def setup_data(
130+
self, dataset_path: str = "yahma/alpaca-cleaned", dataset_split: str = "train"
131+
):
132+
"""Setup data with configurable dataset path and split."""
133133
print(os.path.join(self.job_config.model.hf_assets_path, "tokenizer.json"))
134134
tokenizer = HuggingFaceModelTokenizer(
135135
tokenizer_json_path=os.path.join(
@@ -146,8 +146,8 @@ def setup_data(self):
146146
dataset = sft_iterable_dataset(
147147
model_transform=tokenizer,
148148
message_transform=AlpacaToMessages(),
149-
path="yahma/alpaca-cleaned",
150-
split="train",
149+
path=dataset_path,
150+
split=dataset_split,
151151
)
152152
packer = TextPacker(padding_idx=0)
153153
dataset = PackedDataset(
@@ -163,10 +163,6 @@ def setup_data(self):
163163
),
164164
)
165165

166-
# Ultimately we probably want something like this
167-
# packer = build_packing_strategy(packing_config)
168-
# dataset = build_dataset(dataset_config)
169-
# dataloader = build_dataloader(dataloader_config, dataset, packer)
170166
return dataloader
171167

172168
def forward_backward(
@@ -206,7 +202,6 @@ def forward_backward(
206202
)
207203

208204
# accumulate losses across pipeline microbatches
209-
# TODO: PP+FSDP unexpectedly puts the loss back to the CPU
210205
loss = (
211206
torch.mean(torch.stack(losses)).to(self.device)
212207
if self.pp_has_last_stage
@@ -225,27 +220,125 @@ def forward_backward(
225220

226221
return loss
227222

223+
def forward_only(
    self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
) -> torch.Tensor:
    """Forward pass only for evaluation (no backward).

    Mirrors the training forward path (same context-parallel and
    pipeline-parallel setup) but never calls ``backward``. Intended to be
    invoked under ``torch.no_grad()`` by the caller (see ``evaluate``).

    Args:
        input_dict: Batch dict; only the "tokens" tensor is consumed here.
        labels: Target tensor passed to the loss / pipeline schedule.

    Returns:
        Loss tensor on ``self.device``. Under pipeline parallelism only the
        last stage computes a real loss; other stages return a ``-1.0``
        placeholder tensor.
    """
    model_parts = self.model_parts
    parallel_dims = self.parallel_dims

    inputs = input_dict["tokens"]
    # Optionally shard the sequence dimension across the context-parallel
    # mesh — same buffers/seq-dims as the training path uses.
    optional_context_parallel_ctx = (
        dist_utils.create_context_parallel_ctx(
            cp_mesh=parallel_dims.world_mesh["cp"],
            cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],
            cp_seq_dims=[1, 1] + [0 for _ in model_parts],
            cp_no_restore_buffers={inputs, labels},
            cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
        )
        if parallel_dims.cp_enabled
        else None
    )

    if parallel_dims.pp_enabled:
        # Pipeline Parallel forward only
        with self.train_context(optional_context_parallel_ctx):
            # Only the last stage receives targets and collects losses.
            targets, losses = (
                (labels, []) if self.pp_has_last_stage else (None, None)
            )
            if self.pp_has_first_stage:
                self.pp_schedule.step(
                    inputs, target=targets, losses=losses, input_batch=inputs
                )
            else:
                self.pp_schedule.step(
                    target=targets, losses=losses, input_batch=inputs
                )

        # Mean over pipeline microbatch losses on the last stage;
        # other stages report a -1.0 sentinel.
        loss = (
            torch.mean(torch.stack(losses)).to(self.device)
            if self.pp_has_last_stage
            else torch.tensor([-1.0], device=self.device)
        )
    else:
        # Non-PP forward only
        with self.train_context(optional_context_parallel_ctx):
            assert len(model_parts) == 1
            with self.maybe_enable_amp:
                pred = model_parts[0](inputs)
                loss = self.loss_fn(pred, labels)
            # Logits are no longer needed; free them promptly to cut
            # peak memory during evaluation.
            del pred

    return loss
228274
def train_step(self, batch) -> None:
    """Run one optimization step on a single device-resident batch.

    Pops "labels" from *batch*, runs forward+backward, then applies the
    optimizer and LR-scheduler updates.
    """
    # Clear gradients left over from the previous step. The training loop
    # only zeroes once before the loop, so without this every batch's
    # gradients would keep accumulating step after step.
    self.optimizers.zero_grad()

    labels = batch.pop("labels")
    loss = self.forward_backward(batch, labels)

    # Lazy %-formatting: the message is only built if INFO is enabled.
    logger.info(
        "%s / %s|Loss: %s", self.current_step, self.num_training_steps, loss
    )
    self.optimizers.step()
    self.lr_schedulers.step()

282+
async def evaluate(self) -> dict[str, float]:
    """Run evaluation on validation set (internal method, not an endpoint).

    Averages the forward-only loss over up to ``self.eval_steps`` batches
    drawn from a fresh iterator over ``self.val_dataloader`` (evaluation
    always restarts from the beginning of the split). Model parts are
    switched to eval mode for the duration and restored to train mode
    before returning.

    Returns:
        Dict with "val_loss" (mean loss; 0.0 if no batch ran) and
        "val_batches" (number of batches actually evaluated).
    """
    logger.info("=" * 50)
    logger.info("STARTING EVALUATION ")
    logger.info("=" * 50)

    # Set model to eval mode (disables dropout etc.).
    for model_part in self.model_parts:
        model_part.eval()

    val_dataloader = iter(self.val_dataloader)
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for _ in range(self.eval_steps):
            # Keep the try body minimal: only next() legitimately raises
            # StopIteration here. Previously the whole iteration body was
            # inside the try, so a failure in the forward pass or device
            # transfer was silently mislabeled as "end of dataloader".
            try:
                batch = next(val_dataloader)
            except StopIteration:
                logger.warning("Reached end of validation dataloader early")
                break

            # Move tensors to device
            for k, v in batch.items():
                if isinstance(v, torch.Tensor):
                    batch[k] = v.to(self.device)

            labels = batch.pop("labels")
            loss = self.forward_only(batch, labels)

            # Hoist .item() — one device sync instead of two per batch.
            loss_value = loss.item()
            total_loss += loss_value
            num_batches += 1

            logger.info(
                f" Eval batch {num_batches}/{self.eval_steps} | Loss: {loss_value:.4f}"
            )

    # Set model back to train mode
    for model_part in self.model_parts:
        model_part.train()

    # max(..., 1) guards eval_steps == 0 and an empty dataloader.
    avg_loss = total_loss / max(num_batches, 1)

    metrics = {
        "val_loss": avg_loss,
        "val_batches": num_batches,
    }

    logger.info("-" * 50)
    logger.info("EVALUATION COMPLETE")
    logger.info(f"Validation Loss: {avg_loss:.4f}")
    logger.info(f"Batches Evaluated: {num_batches}")
    logger.info("=" * 50)
    return metrics
244338
@endpoint
245339
async def train(self) -> None:
246340
dataloader = iter(self.train_dataloader)
247341
self.optimizers.zero_grad()
248-
249342
# TODO: tqdm is broken in Monarch actors
250343
# self.pbar = tqdm(initial=self.current_step, total=self.num_training_steps)
251344

@@ -254,18 +347,21 @@ async def train(self) -> None:
254347
# Move tensors to the appropriate device
255348
for k, v in batch.items():
256349
if isinstance(v, torch.Tensor):
257-
batch[k] = v.to("cuda") # TODO: hardcoded for now
350+
batch[k] = v.to(self.device) # TODO: hardcoded for now
258351
self.train_step(batch)
259-
# self.profiler.step()
260352
self.current_step += 1
261353

354+
# Run evaluation periodically
355+
if self.current_step % self.eval_interval == 0:
356+
eval_metrics = await self.evaluate()
357+
logger.info(f"Step {self.current_step} | Eval metrics: {eval_metrics}")
358+
359+
# Save checkpoints
262360
self.checkpointer.save(
263361
curr_step=self.current_step,
264362
last_step=self.current_step == self.num_training_steps,
265363
)
266364

267-
# self.pbar.close()
268-
269365
@endpoint
270366
async def cleanup(self) -> None:
271367
if self.checkpointer:

0 commit comments

Comments
 (0)