Adding eval to the SFT #404
@@ -7,7 +7,6 @@
"""To run:

python -m apps.sft.main --config apps/sft/llama3_8b.yaml

"""

import asyncio

@@ -40,8 +39,6 @@
from torchtitan.experiments.forge.engine import ForgeEngine
from torchtitan.experiments.forge.job_config import ForgeJobConfig

# from tqdm import tqdm

# stubs for now
Checkpointer = Any
Dataloader = Any

@@ -64,7 +61,7 @@ class ForgeSFTRecipe(ForgeActor, ForgeEngine):
    checkpointer: Checkpointer
    tokenizer: Tokenizer
    train_dataloader: Dataloader
    # val_dataloader: Dataloader
    val_dataloader: Dataloader
    metric_logger: MetricLogger
    profiler: Profiler
    device: torch.device

@@ -81,6 +78,11 @@ def __init__(self, config: DictConfig):
        self.gradient_accumulation_steps = 1  # Example value, adjust as needed
        self._rank = current_rank().rank
        self._size = math.prod(current_size().values())

        # Evaluation settings
        self.eval_interval = job_config.training.get("eval_interval", float("inf"))
        self.eval_steps = job_config.training.get("eval_steps", 0)

        self._init_dist()
        super().__init__(job_config)

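Side note for reviewers: a minimal sketch (not part of this PR) of how these two settings are expected to be supplied, assuming the job config behaves like an OmegaConf DictConfig as the type hints suggest; the config keys are inferred from the .get() calls above, whose defaults mean "never evaluate" and "no cap on eval batches".

# Illustration only - config keys mirror the .get() calls in __init__, defaults assumed.
from omegaconf import OmegaConf

job_config = OmegaConf.create(
    {"training": {"eval_interval": 100, "eval_steps": 50}}
)

# Run evaluate() every 100 training steps (default float("inf") disables eval).
eval_interval = job_config.training.get("eval_interval", float("inf"))
# Cap each evaluation run at 50 batches (default 0 means run a full validation epoch).
eval_steps = job_config.training.get("eval_steps", 0)
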
@@ -111,25 +113,23 @@ def _init_dist(self):

    @endpoint
    async def setup(self):
        self.train_dataloader = self.setup_data()
        # self.train_dataloader = self.setup_data(
        #     self.train_config.train_dataset_config,
        #     self.train_config.train_dataloader_config,
        #     self.train_config.packing_config,
        # )
        # self.val_dataloader = self.setup_data(
        #     self.train_config.val_dataset_config,
        #     self.train_config.val_dataloader_config,
        #     self.train_config.packing_config,
        # )

        # TODO: confirm that this is working properly
        # Should also use load, not dcp_load
        # Setup training data (first 90% of train split)
        self.train_dataloader = self.setup_data(
            dataset_path="yahma/alpaca-cleaned", dataset_split="train[:90%]"
        )

        # Setup validation data (last 10% of train split)
        self.val_dataloader = self.setup_data(
            dataset_path="yahma/alpaca-cleaned", dataset_split="train[90%:]"
        )

        # Load checkpoint if resuming
        self.checkpointer.load(step=self.current_step)
        # self.profiler = self.setup_profiler(self.train_config.profiler_config)
        # self.logger = self.setup_logger(self.train_config.logger_config)

    def setup_data(self):
    def setup_data(
        self, dataset_path: str = "yahma/alpaca-cleaned", dataset_split: str = "train"
    ):
        """Setup data with configurable dataset path and split."""
        print(os.path.join(self.job_config.model.hf_assets_path, "tokenizer.json"))
        tokenizer = HuggingFaceModelTokenizer(
            tokenizer_json_path=os.path.join(

@@ -146,8 +146,8 @@ def setup_data(self):
        dataset = sft_iterable_dataset(
            model_transform=tokenizer,
            message_transform=AlpacaToMessages(),
            path="yahma/alpaca-cleaned",
            split="train",
            path=dataset_path,
            split=dataset_split,
        )
        packer = TextPacker(padding_idx=0)
        dataset = PackedDataset(

Review comment (on message_transform=AlpacaToMessages()): Ideally we shouldn't hardcode this either (but it's a bit more work without instantiate).
Reply: I agree. We can have this implemented as soon as we fix the main eval functionality.

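As a possible follow-up to the thread above, one way to avoid hardcoding the dataset would be to read the path and splits from the job config. A rough sketch, assuming hypothetical config keys (dataset_path, train_split, val_split) that do not exist in this PR:

    # Hypothetical helper - the config keys below are assumptions, not part of this PR.
    def _dataset_settings(self):
        training_cfg = self.job_config.training
        return (
            training_cfg.get("dataset_path", "yahma/alpaca-cleaned"),
            training_cfg.get("train_split", "train[:90%]"),
            training_cfg.get("val_split", "train[90%:]"),
        )

    # Then, inside setup():
    #     path, train_split, val_split = self._dataset_settings()
    #     self.train_dataloader = self.setup_data(dataset_path=path, dataset_split=train_split)
    #     self.val_dataloader = self.setup_data(dataset_path=path, dataset_split=val_split)
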
@@ -163,10 +163,6 @@ def setup_data(self):
            ),
        )

        # Ultimately we probably want something like this
        # packer = build_packing_strategy(packing_config)
        # dataset = build_dataset(dataset_config)
        # dataloader = build_dataloader(dataloader_config, dataset, packer)
        return dataloader

    def forward_backward(

@@ -206,7 +202,6 @@ def forward_backward(
        )

        # accumulate losses across pipeline microbatches
        # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
        loss = (
            torch.mean(torch.stack(losses)).to(self.device)
            if self.pp_has_last_stage

@@ -225,27 +220,213 @@ def forward_backward(

        return loss

    def forward_only(
        self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
    ) -> torch.Tensor:
        """Forward pass only for evaluation (no backward)."""
        model_parts = self.model_parts
        parallel_dims = self.parallel_dims

        inputs = input_dict["tokens"]
        optional_context_parallel_ctx = (
            dist_utils.create_context_parallel_ctx(
                cp_mesh=parallel_dims.world_mesh["cp"],
                cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],
                cp_seq_dims=[1, 1] + [0 for _ in model_parts],
                cp_no_restore_buffers={inputs, labels},
                cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
            )
            if parallel_dims.cp_enabled
            else None
        )

        if parallel_dims.pp_enabled:
            # Pipeline Parallel forward only
            with self.train_context(optional_context_parallel_ctx):
                targets, losses = (
                    (labels, []) if self.pp_has_last_stage else (None, None)
                )
                if self.pp_has_first_stage:
                    self.pp_schedule.step(
                        inputs, target=targets, losses=losses, input_batch=inputs
                    )
                else:
                    self.pp_schedule.step(
                        target=targets, losses=losses, input_batch=inputs
                    )

            loss = (
                torch.mean(torch.stack(losses)).to(self.device)
                if self.pp_has_last_stage
                else torch.tensor([-1.0], device=self.device)
            )
        else:
            # Non-PP forward only
            with self.train_context(optional_context_parallel_ctx):
                assert len(model_parts) == 1
                with self.maybe_enable_amp:
                    pred = model_parts[0](inputs)
                    loss = self.loss_fn(pred, labels)
                del pred

        return loss

    def train_step(self, batch) -> None:
        # TODO
        # with GradientAccumulation(
        #     self.gradient_accumulation_steps,
        #     self.model,
        #     self.data_parallel_size,
        # ) as grad_acc:
        labels = batch.pop("labels")
        loss = self.forward_backward(batch, labels)

        logger.info(f"{self.current_step} / {self.num_training_steps}|Loss: {loss}")
        # self.pbar.set_description(f"{self.current_step}|Loss: {loss}")
        # self.pbar.update(1)
        self.optimizers.step()
        self.lr_schedulers.step()

    def _extract_epoch_from_batch(self, batch: dict) -> int | None:
        """Extract epoch number from batch metrics."""
        if "metrics" not in batch:
            return None

        for metric in batch["metrics"]:
            if hasattr(metric, "metric_name") and metric.metric_name == "num_epochs":
                return metric.value
        return None

    async def evaluate(self) -> dict[str, float]:
        """Run evaluation on the validation set for one complete epoch.

        Uses a prefetch + non-blocking all_reduce pattern to detect epoch completion
        across all ranks without blocking on every batch.

        Pattern:
        - Iteration N: start an async all_reduce on the next batch's epoch (non-blocking)
        - Process the current batch while the all_reduce completes in the background
        - Iteration N+1: check the result from the previous all_reduce (should be done)

        This overlaps communication with computation for better performance.
        """
        logger.info("=" * 50)
        logger.info("STARTING EVALUATION")
        logger.info("=" * 50)

        # Set model to eval mode
        for model_part in self.model_parts:
            model_part.eval()

        val_dataloader = iter(self.val_dataloader)
        total_loss = 0.0
        num_batches = 0
        starting_epoch = None

        # Prefetch first batch
        try:
            next_batch = next(val_dataloader)
        except StopIteration:
            logger.warning("Validation dataloader is empty")
            return {"val_loss": 0.0, "val_batches": 0}

        next_should_break = False
        pending_work = None  # Handle for async all_reduce
        epoch_tensor = None  # Tensor for all_reduce result

        with torch.no_grad():
            while True:
                # Check result from PREVIOUS iteration's async all_reduce
                if pending_work is not None:
                    pending_work.wait()  # Should be complete (or very fast) since we did compute
                    if epoch_tensor is not None:
                        next_should_break = epoch_tensor.item() > 0
                    pending_work = None

                # Check if we should break (based on previous iteration's check)
                if next_should_break:
                    logger.info(
                        "Epoch completed across all ranks - stopping evaluation"
                    )
                    break

                # Check optional cap on eval steps
                if self.eval_steps > 0 and num_batches >= self.eval_steps:
                    logger.info(f"Reached eval_steps cap of {self.eval_steps}")
                    break

                # Use the batch that was prefetched in the previous iteration
                batch = next_batch

                # Extract epoch from current batch
                current_epoch = self._extract_epoch_from_batch(batch)
                if current_epoch is not None and starting_epoch is None:
                    starting_epoch = current_epoch
                    logger.info(f"Starting evaluation at epoch {starting_epoch}")

                # Prefetch next batch and start async all_reduce
                try:
                    next_batch = next(val_dataloader)

                    # Extract epoch from next batch
                    next_epoch = self._extract_epoch_from_batch(next_batch)

                    # Start NON-BLOCKING all_reduce to check if any rank completed the epoch
                    if next_epoch is not None and starting_epoch is not None:
                        # Check if the next batch indicates epoch completion
                        epoch_increment = next_epoch - starting_epoch

                        if torch.distributed.is_initialized():
                            # Create tensor for all_reduce
                            epoch_tensor = torch.tensor(
                                [epoch_increment], dtype=torch.long, device=self.device
                            )
                            # Start async all_reduce (returns immediately, doesn't block)
                            pending_work = torch.distributed.all_reduce(
                                epoch_tensor,
                                op=torch.distributed.ReduceOp.MAX,
                                async_op=True,  # NON-BLOCKING - returns immediately
                            )
                        else:
                            # Single-rank case - just check locally
                            next_should_break = epoch_increment > 0

                except StopIteration:
                    # No more batches - this is the last one
                    next_should_break = True

                # Process the current batch (while the all_reduce completes in the background)
                # Move tensors to device
                for k, v in batch.items():
                    if isinstance(v, torch.Tensor):
                        batch[k] = v.to(self.device)

                labels = batch.pop("labels")
                loss = self.forward_only(batch, labels)
                # GPU compute happens here while the network does the all_reduce

                total_loss += loss.item()
                num_batches += 1

                eval_steps_info = f"/{self.eval_steps}" if self.eval_steps > 0 else ""
                logger.info(
                    f"Eval batch {num_batches}{eval_steps_info} | Loss: {loss.item():.4f}"
                )

        # Set model back to train mode
        for model_part in self.model_parts:
            model_part.train()

        avg_loss = total_loss / max(num_batches, 1)

        metrics = {
            "val_loss": avg_loss,
            "val_batches": num_batches,
        }

        logger.info("-" * 50)
        logger.info("EVALUATION COMPLETE")
        logger.info(f"Validation Loss: {avg_loss:.4f}")
        logger.info(f"Batches Evaluated: {num_batches}")
        logger.info("=" * 50)
        return metrics

    @endpoint
    async def train(self) -> None:
        dataloader = iter(self.train_dataloader)
        self.optimizers.zero_grad()

        # TODO: tqdm is broken in Monarch actors
        # self.pbar = tqdm(initial=self.current_step, total=self.num_training_steps)

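For readers skimming the diff, here is the prefetch/overlap idea from evaluate() reduced to a self-contained sketch; the helper name and structure are illustrative only and assume torch.distributed has already been initialized.

import torch
import torch.distributed as dist

def start_epoch_check(epoch_increment: int, device: torch.device):
    """Kick off a MAX all_reduce without blocking; the caller runs its forward
    pass first and only calls wait() on the returned handle afterwards."""
    flag = torch.tensor([epoch_increment], dtype=torch.long, device=device)
    work = dist.all_reduce(flag, op=dist.ReduceOp.MAX, async_op=True)  # returns immediately
    return flag, work

# Schematic use inside the eval loop:
#     flag, work = start_epoch_check(next_epoch - starting_epoch, self.device)
#     loss = self.forward_only(batch, labels)  # GPU compute overlaps the all_reduce
#     work.wait()                              # usually already complete by now
#     if flag.item() > 0:
#         break                                # some rank has finished the validation epoch
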
@@ -254,18 +435,21 @@ async def train(self) -> None:
            # Move tensors to the appropriate device
            for k, v in batch.items():
                if isinstance(v, torch.Tensor):
                    batch[k] = v.to("cuda")  # TODO: hardcoded for now
                    batch[k] = v.to(self.device)  # TODO: hardcoded for now

            self.train_step(batch)
            # self.profiler.step()
            self.current_step += 1

            # Run evaluation periodically
            if self.current_step % self.eval_interval == 0:
                eval_metrics = await self.evaluate()
                logger.info(f"Step {self.current_step} | Eval metrics: {eval_metrics}")

            # Save checkpoints
            self.checkpointer.save(
                curr_step=self.current_step,
                last_step=self.current_step == self.num_training_steps,
            )

        # self.pbar.close()

    @endpoint
    async def cleanup(self) -> None:
        if self.checkpointer: