Adding eval to the SFT #404
base: main
Changes from 4 commits
@@ -7,7 +7,6 @@
"""To run:

python -m apps.sft.main --config apps/sft/llama3_8b.yaml

"""

import asyncio
@@ -40,8 +39,6 @@
from torchtitan.experiments.forge.engine import ForgeEngine
from torchtitan.experiments.forge.job_config import ForgeJobConfig

# from tqdm import tqdm

# stubs for now
Checkpointer = Any
Dataloader = Any
@@ -64,7 +61,7 @@ class ForgeSFTRecipe(ForgeActor, ForgeEngine):
    checkpointer: Checkpointer
    tokenizer: Tokenizer
    train_dataloader: Dataloader
    # val_dataloader: Dataloader
    val_dataloader: Dataloader
    metric_logger: MetricLogger
    profiler: Profiler
    device: torch.device
@@ -81,6 +78,11 @@ def __init__(self, config: DictConfig):
        self.gradient_accumulation_steps = 1  # Example value, adjust as needed
        self._rank = current_rank().rank
        self._size = math.prod(current_size().values())

        # Evaluation settings
        self.eval_interval = job_config.training.get("eval_interval", float("inf"))
        self.eval_steps = job_config.training.get("eval_steps", 0)

        self._init_dist()
        super().__init__(job_config)
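For reference, a minimal sketch of how these two settings behave with the defaults used here (plain Python with a stand-in dict; the real values come from the training section of the job config). Because `eval_interval` defaults to `float("inf")`, the `current_step % eval_interval == 0` check added later in `train()` never fires, so evaluation stays disabled unless the config sets a finite interval.

# Minimal sketch of the defaulting behavior, using a plain dict in place of
# job_config.training. The keys eval_interval / eval_steps are the ones added
# in this PR; the numeric values below are made up for illustration.
training_cfg = {"eval_interval": 500, "eval_steps": 50}

eval_interval = training_cfg.get("eval_interval", float("inf"))
eval_steps = training_cfg.get("eval_steps", 0)

for step in range(1, 1501):
    if step % eval_interval == 0:
        # With the inf default this branch is never taken (n % inf == n for n > 0).
        print(f"step {step}: run evaluation for up to {eval_steps} batches")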
@@ -111,25 +113,23 @@ def _init_dist(self):

    @endpoint
    async def setup(self):
        self.train_dataloader = self.setup_data()
        # self.train_dataloader = self.setup_data(
        #     self.train_config.train_dataset_config,
        #     self.train_config.train_dataloader_config,
        #     self.train_config.packing_config,
        # )
        # self.val_dataloader = self.setup_data(
        #     self.train_config.val_dataset_config,
        #     self.train_config.val_dataloader_config,
        #     self.train_config.packing_config,
        # )

        # TODO: confirm that this is working properly
        # Should also use load, not dcp_load
        # Setup training data (first 90% of train split)
        self.train_dataloader = self.setup_data(
            dataset_path="yahma/alpaca-cleaned", dataset_split="train[:90%]"
        )

        # Setup validation data (last 10% of train split)
        self.val_dataloader = self.setup_data(
            dataset_path="yahma/alpaca-cleaned", dataset_split="train[90%:]"
        )

        # Load checkpoint if resuming
        self.checkpointer.load(step=self.current_step)
        # self.profiler = self.setup_profiler(self.train_config.profiler_config)
        # self.logger = self.setup_logger(self.train_config.logger_config)

    def setup_data(self):
    def setup_data(
        self, dataset_path: str = "yahma/alpaca-cleaned", dataset_split: str = "train"
    ):
        """Setup data with configurable dataset path and split."""
        print(os.path.join(self.job_config.model.hf_assets_path, "tokenizer.json"))
        tokenizer = HuggingFaceModelTokenizer(
            tokenizer_json_path=os.path.join(
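An aside on the split strings above: assuming `setup_data` ultimately forwards `dataset_path` and `dataset_split` to Hugging Face `datasets` (the `train[:90%]` / `train[90%:]` notation is that library's percentage slicing), the 90/10 carve-up can be sanity-checked in isolation:

# Standalone sketch (not part of the recipe): shows what the percentage-slice
# split strings select, assuming the underlying loader is Hugging Face `datasets`.
from datasets import load_dataset

train_ds = load_dataset("yahma/alpaca-cleaned", split="train[:90%]")  # first 90%
val_ds = load_dataset("yahma/alpaca-cleaned", split="train[90%:]")    # last 10%
print(len(train_ds), len(val_ds))  # disjoint splits that together cover the full train set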
@@ -146,8 +146,8 @@ def setup_data(self):
        dataset = sft_iterable_dataset(
            model_transform=tokenizer,
            message_transform=AlpacaToMessages(),

Review comment: Ideally we shouldn't hardcode this either (but it's a bit more work without instantiate)
Reply: I agree. We can have this implemented as soon as we fix the main eval functionality

            path="yahma/alpaca-cleaned",
            split="train",
            path=dataset_path,
            split=dataset_split,
        )
        packer = TextPacker(padding_idx=0)
        dataset = PackedDataset(
@@ -163,10 +163,6 @@ def setup_data(self):
            ),
        )

        # Ultimately we probably want something like this
        # packer = build_packing_strategy(packing_config)
        # dataset = build_dataset(dataset_config)
        # dataloader = build_dataloader(dataloader_config, dataset, packer)
        return dataloader

    def forward_backward(
@@ -206,7 +202,6 @@ def forward_backward(
        )

        # accumulate losses across pipeline microbatches
        # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
        loss = (
            torch.mean(torch.stack(losses)).to(self.device)
            if self.pp_has_last_stage
@@ -225,27 +220,125 @@ def forward_backward(

        return loss

    def forward_only(
        self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
    ) -> torch.Tensor:
        """Forward pass only for evaluation (no backward)."""
        model_parts = self.model_parts
        parallel_dims = self.parallel_dims

        inputs = input_dict["tokens"]
        optional_context_parallel_ctx = (
            dist_utils.create_context_parallel_ctx(
                cp_mesh=parallel_dims.world_mesh["cp"],
                cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],
                cp_seq_dims=[1, 1] + [0 for _ in model_parts],
                cp_no_restore_buffers={inputs, labels},
                cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
            )
            if parallel_dims.cp_enabled
            else None
        )

        if parallel_dims.pp_enabled:
            # Pipeline Parallel forward only
            with self.train_context(optional_context_parallel_ctx):
                targets, losses = (
                    (labels, []) if self.pp_has_last_stage else (None, None)
                )
                if self.pp_has_first_stage:
                    self.pp_schedule.step(
                        inputs, target=targets, losses=losses, input_batch=inputs
                    )
                else:
                    self.pp_schedule.step(
                        target=targets, losses=losses, input_batch=inputs
                    )

            loss = (
                torch.mean(torch.stack(losses)).to(self.device)
                if self.pp_has_last_stage
                else torch.tensor([-1.0], device=self.device)
            )
        else:
            # Non-PP forward only
            with self.train_context(optional_context_parallel_ctx):
                assert len(model_parts) == 1
                with self.maybe_enable_amp:
                    pred = model_parts[0](inputs)
                    loss = self.loss_fn(pred, labels)
                del pred

        return loss
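A side note on the eval-mode handling used by evaluate() below: `model.eval()` and `torch.no_grad()` do different things, and evaluation generally wants both. A tiny, self-contained PyTorch illustration (plain PyTorch, not the recipe's ForgeEngine classes):

# Plain-PyTorch illustration: eval() switches modules such as Dropout to
# inference behavior, while no_grad() stops autograd from recording the
# forward pass; evaluate() below relies on both.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.Dropout(p=0.5), nn.Linear(8, 1))
x = torch.randn(4, 8)

model.eval()              # dropout becomes a no-op
with torch.no_grad():     # no gradient graph is built
    out = model(x)
print(out.requires_grad)  # False
model.train()             # restore training behavior, mirroring the recipe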
    def train_step(self, batch) -> None:
        # TODO
        # with GradientAccumulation(
        #     self.gradient_accumulation_steps,
        #     self.model,
        #     self.data_parallel_size,
        # ) as grad_acc:
        labels = batch.pop("labels")
        loss = self.forward_backward(batch, labels)

        logger.info(f"{self.current_step} / {self.num_training_steps}|Loss: {loss}")
        # self.pbar.set_description(f"{self.current_step}|Loss: {loss}")
        # self.pbar.update(1)
        self.optimizers.step()
        self.lr_schedulers.step()
    async def evaluate(self) -> dict[str, float]:
        """Run evaluation on validation set (internal method, not an endpoint)."""
        logger.info("=" * 50)
        logger.info("STARTING EVALUATION")
        logger.info("=" * 50)

        # Set model to eval mode
        for model_part in self.model_parts:
            model_part.eval()

        val_dataloader = iter(self.val_dataloader)
        total_loss = 0.0
        num_batches = 0

        with torch.no_grad():
            for step in range(self.eval_steps):
                try:
                    batch = next(val_dataloader)

                    # Move tensors to device
                    for k, v in batch.items():
                        if isinstance(v, torch.Tensor):
                            batch[k] = v.to(self.device)

                    labels = batch.pop("labels")
                    loss = self.forward_only(batch, labels)

                    total_loss += loss.item()
                    num_batches += 1

                    logger.info(
                        f"  Eval batch {num_batches}/{self.eval_steps} | Loss: {loss.item():.4f}"
                    )

                except StopIteration:
                    logger.warning("Reached end of validation dataloader early")
                    break

        # Set model back to train mode
        for model_part in self.model_parts:
            model_part.train()

        avg_loss = total_loss / max(num_batches, 1)

        metrics = {
            "val_loss": avg_loss,
            "val_batches": num_batches,
        }

        logger.info("-" * 50)
        logger.info(f"EVALUATION COMPLETE")
        logger.info(f"Validation Loss: {avg_loss:.4f}")
        logger.info(f"Batches Evaluated: {num_batches}")
        logger.info("=" * 50)
        return metrics
    @endpoint
    async def train(self) -> None:
        dataloader = iter(self.train_dataloader)
        self.optimizers.zero_grad()

        # TODO: tqdm is broken in Monarch actors
        # self.pbar = tqdm(initial=self.current_step, total=self.num_training_steps)
@@ -254,18 +347,21 @@ async def train(self) -> None:
            # Move tensors to the appropriate device
            for k, v in batch.items():
                if isinstance(v, torch.Tensor):
                    batch[k] = v.to("cuda")  # TODO: hardcoded for now
                    batch[k] = v.to(self.device)  # TODO: hardcoded for now

            self.train_step(batch)
            # self.profiler.step()
            self.current_step += 1

            # Run evaluation periodically
            if self.current_step % self.eval_interval == 0:
                eval_metrics = await self.evaluate()
                logger.info(f"Step {self.current_step} | Eval metrics: {eval_metrics}")

            # Save checkpoints
            self.checkpointer.save(
                curr_step=self.current_step,
                last_step=self.current_step == self.num_training_steps,
            )

        # self.pbar.close()

    @endpoint
    async def cleanup(self) -> None:
        if self.checkpointer:
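To close, a toy sketch (plain asyncio, none of the recipe's classes) of the control flow that the periodic-eval hook above adds to train(): evaluate() is awaited in-line every eval_interval steps, so training simply pauses while validation runs and resumes afterwards.

# Toy control-flow sketch of the periodic evaluation added in this PR.
# evaluate_stub stands in for ForgeSFTRecipe.evaluate; the values are made up.
import asyncio

async def evaluate_stub() -> dict[str, float]:
    return {"val_loss": 0.0, "val_batches": 0}

async def train_stub(num_steps: int = 10, eval_interval: int = 5) -> None:
    for step in range(1, num_steps + 1):
        # ... one optimizer step would happen here ...
        if step % eval_interval == 0:
            metrics = await evaluate_stub()
            print(f"step {step} | eval metrics: {metrics}")

asyncio.run(train_stub())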