Commit 8f46e74

Author: Hossein Kavianihamedani (committed)
Add evaluation functionality to SFT notebook utilities
- Added setup_eval_dataloaders() function to utils.py for multi-dataset evaluation
- Added evaluate() method to TrainerActor for periodic and final evaluation
- Added forward_backward_eval() for evaluation forward passes (no backprop)
- Evaluation supports:
  - Multiple eval datasets
  - Periodic evaluation during training (eval_every_n_steps)
  - Final evaluation at end of training
  - Macro/micro average loss across datasets
  - StopAfterOneEpoch for proper epoch boundaries
  - max_eval_steps cap support
- Fixed docstring to comply with pydoclint
- Now matches full evaluation capabilities from main.py
1 parent 8774767 commit 8f46e74

File tree

2 files changed: +250, -6 lines changed


apps/sft/trainer_actor.py

Lines changed: 197 additions & 6 deletions
@@ -10,6 +10,7 @@
 This is a concrete implementation of BaseForgeActor for supervised fine-tuning.
 """

+import contextlib
 import logging

 import torch
@@ -19,9 +20,11 @@
     create_context_parallel_context,
     log_training_step,
     move_batch_to_device,
+    setup_eval_dataloaders,
     setup_sft_dataloader,
     setup_tokenizer,
 )
+from forge.data.utils import StopAfterOneEpoch
 from monarch.actor import endpoint
 from omegaconf import DictConfig

@@ -34,19 +37,16 @@ class TrainerActor(BaseForgeActor):
     Concrete trainer actor for supervised fine-tuning.

     Handles training loop, forward/backward passes, and checkpoint management.
+
+    Args:
+        config: Configuration dictionary containing training settings
     """

     train_spec: forge_train_spec.ForgeTrainSpec
     train_dataloader: any
     num_training_steps: int

     def __init__(self, config: DictConfig):
-        """
-        Initialize the trainer actor.
-
-        Args:
-            config: Configuration dictionary containing training settings
-        """
         super().__init__(config)
         self.num_training_steps = self.job_config.training.steps

@@ -61,6 +61,7 @@ async def setup(self):
             hf_assets_path=self.job_config.model.hf_assets_path
         )

+        # Setup training dataloader
         self.train_dataloader = setup_sft_dataloader(
             tokenizer=self.tokenizer,
             dataset_path="yahma/alpaca-cleaned",
@@ -70,6 +71,31 @@ async def setup(self):
             device=self.device,
         )

+        # Setup evaluation dataloaders if configured
+        eval_config = self.job_config.get("eval", {})
+        self.val_dataloaders = {}
+        self.eval_every_n_steps = eval_config.get("eval_every_n_steps")
+        max_eval_steps = eval_config.get("max_eval_steps")
+        self.max_eval_steps = (
+            max_eval_steps if max_eval_steps and max_eval_steps > 0 else None
+        )
+        self.validation_enabled = (
+            self.eval_every_n_steps is not None and self.eval_every_n_steps > 0
+        )
+
+        if self.validation_enabled:
+            logger.info("Setting up eval datasets...")
+            eval_datasets_config = eval_config.get("datasets", [])
+            self.val_dataloaders = setup_eval_dataloaders(
+                tokenizer=self.tokenizer,
+                eval_datasets_config=eval_datasets_config,
+                target_tokens_per_pack=self.job_config.training.seq_len,
+                batch_size=self.job_config.training.local_batch_size,
+                device=self.device,
+            )
+            logger.info(f"Loaded {len(self.val_dataloaders)} eval datasets")
+
+        # Load checkpoint if exists
         if self.checkpointer:
             logger.info("Loading checkpoint...")
             self.checkpointer.load(step=self.current_step)
@@ -163,14 +189,179 @@ async def run(self) -> None:
             self.train_step(batch)
             self.current_step += 1

+            # Run evaluation periodically if enabled
+            if (
+                self.validation_enabled
+                and self.current_step % self.eval_every_n_steps == 0
+            ):
+                await self.evaluate()
+
             if self.checkpointer:
                 self.checkpointer.save(
                     curr_step=self.current_step,
                     last_step=self.current_step == self.num_training_steps,
                 )

+        # Final evaluation
+        if self.validation_enabled:
+            logger.info("Running final evaluation at end of training...")
+            await self.evaluate()
+
         logger.info("Training complete!")

+    async def evaluate(self) -> None:
+        """
+        Run evaluation on multiple datasets, one at a time.
+
+        1. Set models to eval mode
+        2. For each eval dataset:
+           - Create fresh iterator (starts from epoch 0)
+           - Use StopAfterOneEpoch to iterate until epoch boundary
+           - Respect max_eval_steps cap if configured
+           - Record loss and step metrics
+        3. Restore models to train mode
+        """
+        logger.info("==Starting evaluation==")
+
+        # Set models to eval mode
+        for model_part in self.model_parts:
+            model_part.eval()
+
+        # Get DP mesh for epoch synchronization
+        dp_mesh = None
+        if self.parallel_dims is not None and self.parallel_dims.dp_enabled:
+            dp_mesh = self.parallel_dims.world_mesh.get_group("dp")
+
+        # For non-PP: disable gradients to save memory
+        maybe_no_grad = (
+            contextlib.nullcontext()
+            if self.parallel_dims.pp_enabled
+            else torch.no_grad()
+        )
+
+        # Evaluate each dataset sequentially
+        all_dataset_losses = []
+        all_dataset_steps = []
+
+        for dataset_name, val_dataloader in self.val_dataloaders.items():
+            logger.info(f"=====Evaluating dataset: {dataset_name}=====")
+
+            total_loss = torch.tensor(0.0, device=self.device)
+            num_steps = 0
+
+            # NOTE: Assumes batch contains field "metrics" containing "num_epochs"
+            batch_iter = StopAfterOneEpoch(
+                iter=iter(val_dataloader),  # Fresh iterator from epoch 0
+                device=self.device,
+                dp_mesh=dp_mesh,
+            )
+
+            with maybe_no_grad:
+                for batch in batch_iter:
+                    # If max_eval_steps>len(dataset), it will be stopped earlier
+                    if (
+                        self.max_eval_steps is not None
+                        and num_steps >= self.max_eval_steps
+                    ):
+                        logger.info(
+                            f"[{dataset_name}] Reached max_eval_steps cap of {self.max_eval_steps}"
+                        )
+                        break
+
+                    # Move batch to device
+                    batch = move_batch_to_device(batch, self.device)
+
+                    # Forward pass only (no backward)
+                    labels = batch.pop("labels")
+                    loss = self.forward_backward_eval(batch, labels)
+                    total_loss += loss
+                    num_steps += 1
+
+                    logger.info(
+                        f"[dataset {dataset_name}] Step {num_steps} | Loss: {loss.item():.4f}"
+                    )
+
+            # Log average loss for this dataset
+            avg_loss = (total_loss / max(num_steps, 1)).item()
+            all_dataset_losses.append(avg_loss)
+            all_dataset_steps.append(num_steps)
+            logger.info(
+                f"[dataset {dataset_name}] Final Step {num_steps} | Avg Loss: {avg_loss:.4f}"
+            )

+        # Record macro and micro average losses across datasets
+        if len(all_dataset_losses) > 1:
+            # Macro: same weight for all datasets
+            macro_avg_loss = sum(all_dataset_losses) / len(all_dataset_losses)
+            logger.info(f"Macro avg loss (unweighted): {macro_avg_loss:.4f}")
+
+            # Micro: weighted mean by dataset size
+            total_steps = sum(all_dataset_steps)
+            micro_avg_loss = (
+                sum(
+                    loss * steps
+                    for loss, steps in zip(all_dataset_losses, all_dataset_steps)
+                )
+                / total_steps
+            )
+            logger.info(f"Micro avg loss (weighted): {micro_avg_loss:.4f}")
+
+        # Restore train mode
+        for model_part in self.model_parts:
+            model_part.train()
+
+        logger.info("==Evaluation complete==")
+
+    def forward_backward_eval(
+        self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
+    ) -> torch.Tensor:
+        """
+        Perform forward pass only (for evaluation).
+
+        Args:
+            input_dict: Dictionary containing input tokens
+            labels: Ground truth labels
+
+        Returns:
+            Computed loss value
+        """
+        model_parts = self.model_parts
+        parallel_dims = self.parallel_dims
+        inputs = input_dict["tokens"]
+
+        optional_context_parallel_ctx = create_context_parallel_context(
+            parallel_dims=parallel_dims,
+            inputs=inputs,
+            labels=labels,
+            model_parts=model_parts,
+            rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
+        )
+
+        if parallel_dims.pp_enabled:
+            with self.train_context(optional_context_parallel_ctx):
+                targets, losses = (
+                    (labels, []) if self.pp_has_last_stage else (None, None)
+                )
+                if self.pp_has_first_stage:
+                    self.pp_schedule.step(inputs, target=targets, losses=losses)
+                else:
+                    self.pp_schedule.step(target=targets, losses=losses)
+
+            loss = (
+                torch.sum(torch.stack(losses)).to(self.device)
+                if self.pp_has_last_stage
+                else torch.tensor(-1.0, device=self.device)
+            )
+        else:
+            with self.train_context(optional_context_parallel_ctx):
+                assert len(model_parts) == 1
+                with self.maybe_enable_amp:
+                    pred = model_parts[0](inputs)
+                    loss = self.loss_fn(pred, labels)
+                del pred
+
+        return loss
+
     @endpoint
     async def cleanup(self) -> None:
         """

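As a quick illustration of the macro/micro distinction logged by evaluate() above, here is a self-contained sketch of the same arithmetic with made-up per-dataset losses and step counts (illustrative numbers, not results from this commit).

# Standalone sketch of the macro/micro averaging in evaluate(); numbers are made up.
all_dataset_losses = [1.20, 0.80]  # average eval loss per dataset
all_dataset_steps = [10, 30]       # eval batches processed per dataset

# Macro: every dataset weighted equally.
macro_avg_loss = sum(all_dataset_losses) / len(all_dataset_losses)  # (1.20 + 0.80) / 2 = 1.00

# Micro: weighted by number of eval steps, so larger datasets count more.
total_steps = sum(all_dataset_steps)
micro_avg_loss = (
    sum(loss * steps for loss, steps in zip(all_dataset_losses, all_dataset_steps))
    / total_steps
)  # (1.20 * 10 + 0.80 * 30) / 40 = 0.90

print(f"macro={macro_avg_loss:.2f}  micro={micro_avg_loss:.2f}")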
apps/sft/utils.py

Lines changed: 53 additions & 0 deletions
@@ -185,3 +185,56 @@ def log_training_step(
         logger: Logger instance
     """
     logger.info(f"Step {step}/{total_steps} | Loss: {loss.item():.4f}")
+
+
+def setup_eval_dataloaders(
+    tokenizer: HuggingFaceModelTokenizer,
+    eval_datasets_config: list[dict],
+    target_tokens_per_pack: int,
+    batch_size: int,
+    device: torch.device,
+    padding_idx: int = 0,
+    message_transform: Optional[Any] = None,
+    dp_mesh: Optional[Any] = None,
+) -> dict[str, StatefulDataLoader]:
+    """
+    Setup multiple evaluation dataloaders from config.
+
+    Args:
+        tokenizer: Tokenizer to use for processing text
+        eval_datasets_config: List of eval dataset configurations
+        target_tokens_per_pack: Target sequence length for packing
+        batch_size: Batch size for evaluation
+        device: Device to move tensors to
+        padding_idx: Padding token index
+        message_transform: Transform to convert dataset format to messages
+        dp_mesh: Data parallel mesh for distributed evaluation
+
+    Returns:
+        Dictionary mapping dataset names to their dataloaders
+    """
+    if message_transform is None:
+        message_transform = AlpacaToMessages()
+
+    val_dataloaders = {}
+
+    for i, dataset_config in enumerate(eval_datasets_config):
+        ds_name = dataset_config.get("dataset_name", i)
+
+        logger.info(f"Loading eval dataset: {ds_name}")
+
+        # Use the same setup_sft_dataloader but with eval dataset config
+        dataloader = setup_sft_dataloader(
+            tokenizer=tokenizer,
+            dataset_path=dataset_config["path"],
+            dataset_split=dataset_config["split"],
+            target_tokens_per_pack=target_tokens_per_pack,
+            batch_size=batch_size,
+            device=device,
+            padding_idx=padding_idx,
+            message_transform=message_transform,
+        )
+
+        val_dataloaders[ds_name] = dataloader
+
+    return val_dataloaders
