Adding eval to the SFT #404
base: main
@@ -7,7 +7,6 @@
"""To run:

python -m apps.sft.main --config apps/sft/llama3_8b.yaml

"""

import asyncio

@@ -40,8 +39,6 @@
from torchtitan.experiments.forge.engine import ForgeEngine
from torchtitan.experiments.forge.job_config import ForgeJobConfig

# from tqdm import tqdm

# stubs for now
Checkpointer = Any
Dataloader = Any

@@ -64,7 +61,7 @@ class ForgeSFTRecipe(ForgeActor, ForgeEngine):
    checkpointer: Checkpointer
    tokenizer: Tokenizer
    train_dataloader: Dataloader
    # val_dataloader: Dataloader
    val_dataloader: Dataloader
    metric_logger: MetricLogger
    profiler: Profiler
    device: torch.device

@@ -81,6 +78,27 @@ def __init__(self, config: DictConfig):
        self.gradient_accumulation_steps = 1  # Example value, adjust as needed
        self._rank = current_rank().rank
        self._size = math.prod(current_size().values())

        # Evaluation settings from validation config
        validation_config = job_config.get("validation", {})
        self.validation_enabled = validation_config.get("enabled", False)

        if self.validation_enabled:
            self.eval_interval = validation_config.get("eval_interval")
            self.eval_steps = validation_config.get("eval_steps")

            if self.eval_interval is None:
                raise ValueError(
                    "validation.eval_interval is required when validation.enabled is true"
                )
            if self.eval_steps is None:
                raise ValueError(
                    "validation.eval_steps is required when validation.enabled is true"
                )
        else:
            self.eval_interval = None
            self.eval_steps = None

        self._init_dist()
        super().__init__(job_config)

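For reference, a minimal sketch of the YAML keys this block (and setup() below) reads. Key names come from the get() calls in the code; the values are illustrative only, not taken from the PR:

validation:
  enabled: true
  eval_interval: 100   # run evaluate() every N training steps
  eval_steps: 50       # cap on eval batches per run (0 means no cap)

dataset:
  path: yahma/alpaca-cleaned
  split: train

dataset_val:
  path: yahma/alpaca-cleaned
  split: test          # illustrative; falls back to dataset.* if omitted
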
@@ -111,25 +129,30 @@ def _init_dist(self):

    @endpoint
    async def setup(self):
        self.train_dataloader = self.setup_data()
        # self.train_dataloader = self.setup_data(
        #     self.train_config.train_dataset_config,
        #     self.train_config.train_dataloader_config,
        #     self.train_config.packing_config,
        # )
        # self.val_dataloader = self.setup_data(
        #     self.train_config.val_dataset_config,
        #     self.train_config.val_dataloader_config,
        #     self.train_config.packing_config,
        # )

        # TODO: confirm that this is working properly
        # Should also use load, not dcp_load
        # Setup training data from config
        dataset_config = self.job_config.get("dataset")

        self.train_dataloader = self.setup_data(
            dataset_path=dataset_config.get("path"),
            dataset_split=dataset_config.get("split"),
        )

        # Setup validation data from config
        dataset_val_config = self.job_config.get("dataset_val", {})
        self.val_dataloader = self.setup_data(
            dataset_path=dataset_val_config.get("path", dataset_config.get("path")),
            dataset_split=dataset_val_config.get("split", dataset_config.get("split")),
        )

        # Load checkpoint if resuming
        self.checkpointer.load(step=self.current_step)
        # self.profiler = self.setup_profiler(self.train_config.profiler_config)
        # self.logger = self.setup_logger(self.train_config.logger_config)

    def setup_data(self):
    def setup_data(self, dataset_path: str, dataset_split: str):
        """Setup data with configurable dataset path and split."""
        if not dataset_path or not dataset_split:
            raise ValueError(
                f"dataset.path and dataset.split are required in YAML config. Got path={dataset_path}, split={dataset_split}"
            )
        print(os.path.join(self.job_config.model.hf_assets_path, "tokenizer.json"))
        tokenizer = HuggingFaceModelTokenizer(
            tokenizer_json_path=os.path.join(

@@ -146,8 +169,8 @@ def setup_data(self):
        dataset = sft_iterable_dataset(
            model_transform=tokenizer,
            message_transform=AlpacaToMessages(),
Review comment: Ideally we shouldn't hardcode this either (but it's a bit more work without instantiate).

Reply: I agree. We can have this implemented as soon as we fix the main eval functionality.
            path="yahma/alpaca-cleaned",
            split="train",
            path=dataset_path,
            split=dataset_split,
        )
        packer = TextPacker(padding_idx=0)
        dataset = PackedDataset(

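On the hardcoded AlpacaToMessages comment above: a minimal sketch of a config-driven alternative. The MESSAGE_TRANSFORMS registry and the dataset.message_transform key are hypothetical, not part of this PR:

        # Hypothetical registry mapping config names to message transforms
        MESSAGE_TRANSFORMS = {"alpaca": AlpacaToMessages}

        transform_name = self.job_config.get("dataset", {}).get("message_transform", "alpaca")
        dataset = sft_iterable_dataset(
            model_transform=tokenizer,
            message_transform=MESSAGE_TRANSFORMS[transform_name](),
            path=dataset_path,
            split=dataset_split,
        )
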
@@ -163,10 +186,6 @@ def setup_data(self):
            ),
        )

        # Ultimately we probably want something like this
        # packer = build_packing_strategy(packing_config)
        # dataset = build_dataset(dataset_config)
        # dataloader = build_dataloader(dataloader_config, dataset, packer)
        return dataloader

    def forward_backward(

@@ -206,7 +225,6 @@ def forward_backward(
        )

        # accumulate losses across pipeline microbatches
        # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
        loss = (
            torch.mean(torch.stack(losses)).to(self.device)
            if self.pp_has_last_stage

@@ -225,27 +243,173 @@ def forward_backward(

        return loss

    def forward_only(
        self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
    ) -> torch.Tensor:
        """Forward pass only for evaluation (no backward)."""
        model_parts = self.model_parts
        parallel_dims = self.parallel_dims

        inputs = input_dict["tokens"]
        optional_context_parallel_ctx = (
            dist_utils.create_context_parallel_ctx(
                cp_mesh=parallel_dims.world_mesh["cp"],
                cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],
                cp_seq_dims=[1, 1] + [0 for _ in model_parts],
                cp_no_restore_buffers={inputs, labels},
                cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
            )
            if parallel_dims.cp_enabled
            else None
        )

        if parallel_dims.pp_enabled:
            # Pipeline Parallel forward only
            with self.train_context(optional_context_parallel_ctx):
                targets, losses = (
                    (labels, []) if self.pp_has_last_stage else (None, None)
                )
                if self.pp_has_first_stage:
                    self.pp_schedule.step(
                        inputs, target=targets, losses=losses, input_batch=inputs
                    )
                else:
                    self.pp_schedule.step(
                        target=targets, losses=losses, input_batch=inputs
                    )

            loss = (
                torch.mean(torch.stack(losses)).to(self.device)
                if self.pp_has_last_stage
                else torch.tensor([-1.0], device=self.device)
            )
        else:
            # Non-PP forward only
            with self.train_context(optional_context_parallel_ctx):
                assert len(model_parts) == 1
                with self.maybe_enable_amp:
                    pred = model_parts[0](inputs)
                    loss = self.loss_fn(pred, labels)
                del pred

        return loss

    def train_step(self, batch) -> None:
        # TODO
        # with GradientAccumulation(
        #     self.gradient_accumulation_steps,
        #     self.model,
        #     self.data_parallel_size,
        # ) as grad_acc:
        labels = batch.pop("labels")
        loss = self.forward_backward(batch, labels)

        logger.info(f"{self.current_step} / {self.num_training_steps}|Loss: {loss}")
        # self.pbar.set_description(f"{self.current_step}|Loss: {loss}")
        # self.pbar.update(1)
        self.optimizers.step()
        self.lr_schedulers.step()

    def _extract_epoch_from_batch(self, batch: dict) -> int | None:
        """Extract epoch number from batch metrics."""
        if "metrics" in batch:
            for metric in batch["metrics"]:
                if (
                    hasattr(metric, "metric_name")
                    and metric.metric_name == "num_epochs"
                ):
                    return metric.value
        return None

    async def evaluate(self) -> dict[str, float]:
Review comment: In addition to @felipemello1's more detailed comments, one higher-level point: this eval implementation is adding a lot of code to the main.py file, which as the entry point is something that everyone will have to read. (Specifically, this PR alone has increased the total LoC by more than 50%, and the evaluate method alone is more than 100 lines due to boundary checking, edge-case handling, etc.) I would like to see if we can find a more minimal way to introduce eval that doesn't expose the user to so much code complexity. I think Felipe's suggestions of offloading to the dataset class, utilities, etc. are valuable, but I would also like to re-raise the option of simplifying by only allowing eval for a fixed number of steps (at least for a first pass). Not going to block on this; if we can do the cross-epoch accounting in a bit more of a clean, minimal way, I am all for it.
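A minimal sketch of the fixed-number-of-steps variant suggested here, not the PR's implementation. It assumes an effectively infinite validation dataloader (so no epoch accounting) and reuses the methods defined in this PR:

    async def evaluate(self) -> dict[str, float]:
        """Minimal eval: run exactly eval_steps batches and average the loss."""
        for m in self.model_parts:
            m.eval()
        val_iter = iter(self.val_dataloader)
        total_loss, num_batches = 0.0, 0
        with torch.no_grad():
            for _ in range(self.eval_steps):
                batch = next(val_iter)
                batch = {
                    k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                    for k, v in batch.items()
                }
                labels = batch.pop("labels")
                total_loss += self.forward_only(batch, labels).item()
                num_batches += 1
        for m in self.model_parts:
            m.train()
        return {"val_loss": total_loss / max(num_batches, 1), "val_batches": num_batches}
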
        """Run evaluation with async all_reduce for cross-rank epoch synchronization."""
Review comment: It might be worth enhancing this docstring a bit. Maybe add a small numerical example.
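One possible expansion along those lines (a sketch only; the numerical example mirrors the logic implemented below):

        """Run evaluation with async all_reduce for cross-rank epoch synchronization.

        Each rank checks whether its next prefetched batch crosses an epoch boundary,
        and the epoch increment is max-reduced across ranks. Example with 2 ranks:
        rank 0 is still on epoch 0 while rank 1's next batch reports epoch 1;
        all_reduce(MAX) of the increment returns 1 on both ranks, so both ranks
        stop evaluating together instead of deadlocking.
        """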
        logger.info("=" * 50)
        logger.info("STARTING EVALUATION")
        logger.info("=" * 50)

        for model_part in self.model_parts:
            model_part.eval()

        val_dataloader = iter(self.val_dataloader)
Review comment: Did you have a chance to consider the multi-dataset case? What happens here?
        total_loss, num_batches, starting_epoch = 0.0, 0, None

        # Prefetch first batch
        try:
            next_batch = next(val_dataloader)
        except StopIteration:
Review comment: I think we can remove the defensive checks and assume that the dataset is infinite. We have a class for it. I think you can just do an assertion that it's TuneIterableDataset (we need to update the name and remove "tune", but don't worry about that in this PR: torchforge/src/forge/data/datasets/dataset.py, line 117 in d464193). Then we know it's infinite, and we can remove the try/except here and later in the loop. Wdyt?
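A sketch of that simplification; the import path and the .dataset attribute on the dataloader are assumptions based on the file referenced above, not verified against the PR:

        from forge.data.datasets.dataset import TuneIterableDataset  # assumed import path

        # Assumes the dataloader exposes the underlying infinite iterable dataset
        assert isinstance(self.val_dataloader.dataset, TuneIterableDataset), (
            "evaluate() assumes an infinite iterable validation dataset"
        )
        next_batch = next(val_dataloader)  # no StopIteration handling needed
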
            logger.warning("Validation dataloader is empty")
            return {"val_loss": 0.0, "val_batches": 0}

        should_break, pending_work, epoch_tensor = False, None, None

        with torch.no_grad():
            while True:
                # Wait for previous async all_reduce to complete
                if pending_work is not None:
Review comment: I am thinking we could abstract most of it into some utility and have this (feel free to change var names):

epoch_incremented, next_max_epoch = False, None
with torch.no_grad():
    while True:
        # check if epoch incremented before getting new batch.
        # If so, stop iterating on the dataset
        epoch_incremented: bool = check_if_epoch_incremented(batch, next_max_epoch)
        if epoch_incremented:
            logger.info("bla bla bla")
            break
        # get next batch
        batch = next_batch
        next_batch = next(val_dataloader)
        # start non-blocking all_reduce for next batch's epoch
        next_max_epoch: futures = get_distributed_max_epoch(next_batch)

Reply: Not 100% sure this works. I think that…
                    pending_work.wait()
                    should_break = (
                        epoch_tensor.item() > 0 if epoch_tensor is not None else False
Review comment: As a rule of thumb, it's not good to use item(), because it requires CPU-GPU synchronization.
                    )
                    pending_work = None

                if should_break:
                    logger.info(
                        "Epoch completed across all ranks - stopping evaluation"
                    )
                    break

                if self.eval_steps > 0 and num_batches >= self.eval_steps:
                    logger.info(f"Reached eval_steps cap of {self.eval_steps}")
                    break

                batch = next_batch

                # Track starting epoch
                current_epoch = self._extract_epoch_from_batch(batch)
                if current_epoch is not None and starting_epoch is None:
                    starting_epoch = current_epoch

                # Prefetch next batch and start async epoch check
                try:
                    next_batch = next(val_dataloader)
                    next_epoch = self._extract_epoch_from_batch(next_batch)

                    if next_epoch is not None and starting_epoch is not None:
Review comment: I am tempted to say let's remove the if/else checks for None. Do you see a strong argument for keeping them? It would only be a problem if someone replaced our dataset abstraction; otherwise we always have the metric.
                        epoch_increment = next_epoch - starting_epoch
                        if torch.distributed.is_initialized():
                            epoch_tensor = torch.tensor(
                                [epoch_increment], dtype=torch.long, device=self.device
                            )
                            pending_work = torch.distributed.all_reduce(
                                epoch_tensor,
                                op=torch.distributed.ReduceOp.MAX,
                                async_op=True,
                            )
                        else:
                            should_break = epoch_increment > 0
                except StopIteration:
                    should_break = True

                # Process current batch (overlaps with async all_reduce)
                for k, v in batch.items():
                    if isinstance(v, torch.Tensor):
                        batch[k] = v.to(self.device)

                labels = batch.pop("labels")
                loss = self.forward_only(batch, labels)
                total_loss += loss.item()
                num_batches += 1

                if num_batches % 10 == 0:
                    logger.info(f"  Eval batch {num_batches} | Loss: {loss.item():.4f}")

        for model_part in self.model_parts:
            model_part.train()

        avg_loss = total_loss / max(num_batches, 1)
Review comment: This might be incorrect. If we have one sample with 10 tokens and another sample with 100 tokens, we have to divide by 110, not by 2. @ebsmothers can you confirm?

Reply: This kinda depends. One interpretation is that the val loss we report is the average over all batches in the val dataloader. In that case this would be the correct implementation. A second interpretation is that the val loss we report is the aggregate loss over all tokens in the val dataset. This is the more "correct" one, as @felipemello1 points out. It especially matters for training; btw, this is why I did not enable gradient accumulation initially, as it needs some special handling. See e.g. the "Long digression" section in the summary of meta-pytorch/torchtune#1917. Since then titan has made some strides in this direction, see this function and this issue. So I think the ideal state is that we use token-normalized loss (rather than batch-normalized as has been done here) for both training and validation. But for training it will likely require a little more work to fully support in titan. For validation it's more straightforward; you can see this snippet in torchtune's validation logic.
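A sketch of token-normalized validation loss under the second interpretation. It assumes the eval loss is computed with reduction="sum" and that padded label positions use an ignore index of -100; both are assumptions, not the PR's code. Accumulating tensors and calling .item() once at the end also avoids the per-batch CPU-GPU sync flagged above:

        IGNORE_IDX = -100  # assumed ignore index for padded label positions

        total_loss = torch.zeros((), device=self.device)
        total_tokens = torch.zeros((), device=self.device)
        with torch.no_grad():
            for _ in range(self.eval_steps):
                batch = next(val_dataloader)
                batch = {
                    k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                    for k, v in batch.items()
                }
                labels = batch.pop("labels")
                # assumes forward_only returns the *sum* of per-token losses here
                total_loss += self.forward_only(batch, labels)
                total_tokens += (labels != IGNORE_IDX).sum()

        # aggregate across data-parallel ranks before normalizing
        if torch.distributed.is_initialized():
            torch.distributed.all_reduce(total_loss)
            torch.distributed.all_reduce(total_tokens)

        avg_loss = (total_loss / total_tokens.clamp(min=1)).item()  # single CPU-GPU sync
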
        logger.info(
            f"EVALUATION COMPLETE | Val Loss: {avg_loss:.4f} | Batches: {num_batches}"
        )
        logger.info("=" * 50)

        return {"val_loss": avg_loss, "val_batches": num_batches}

    @endpoint
    async def train(self) -> None:
        dataloader = iter(self.train_dataloader)
        self.optimizers.zero_grad()

        # TODO: tqdm is broken in Monarch actors
        # self.pbar = tqdm(initial=self.current_step, total=self.num_training_steps)

@@ -254,18 +418,21 @@ async def train(self) -> None:
            # Move tensors to the appropriate device
            for k, v in batch.items():
                if isinstance(v, torch.Tensor):
                    batch[k] = v.to("cuda")  # TODO: hardcoded for now
                    batch[k] = v.to(self.device)  # TODO: hardcoded for now
            self.train_step(batch)
            # self.profiler.step()
            self.current_step += 1

            # Run evaluation periodically if enabled
            if self.validation_enabled and self.current_step % self.eval_interval == 0:
                eval_metrics = await self.evaluate()
                logger.info(f"Step {self.current_step} | Eval metrics: {eval_metrics}")

            # Save checkpoints
            self.checkpointer.save(
                curr_step=self.current_step,
                last_step=self.current_step == self.num_training_steps,
            )

        # self.pbar.close()

    @endpoint
    async def cleanup(self) -> None:
        if self.checkpointer:

Review comment: Nit, but I don't like these nested .get calls. It also seems strange that we would fall back to validation on the training set. Personally I would just recommend checking if validation is enabled, and if it's not, don't even set up the validation dataloader at all.
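A sketch of that suggestion applied to setup(), reusing the fields defined in this PR; treating dataset_val.path and dataset_val.split as required (rather than falling back to the training split) is an assumption:

        # Setup validation data only when validation is enabled
        if self.validation_enabled:
            dataset_val_config = self.job_config.get("dataset_val")
            self.val_dataloader = self.setup_data(
                dataset_path=dataset_val_config.get("path"),
                dataset_split=dataset_val_config.get("split"),
            )
        else:
            self.val_dataloader = None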