make dataset configurable and add validation loop #54
Changes from all commits
@@ -18,9 +18,11 @@
 from forge.data.datasets.packed import PackedDataset, TextPacker
 from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset
 from forge.data.tokenizer import HuggingFaceModelTokenizer
+from forge.data.utils import batch_to_device, CROSS_ENTROPY_IGNORE_IDX

 from omegaconf import DictConfig, OmegaConf
 from torch import nn

 from torchdata.stateful_dataloader import StatefulDataLoader
 from torchtitan.components.loss import LossFunction
 from torchtitan.components.lr_scheduler import LRSchedulersContainer
@@ -30,6 +32,7 @@
 from torchtitan.experiments.forge.job_config import ForgeJobConfig
+from tqdm import tqdm


 # stubs for now
 Checkpointer = Any
 Dataloader = Any
@@ -63,7 +66,16 @@ def __init__(self, job_config: ForgeJobConfig):
         self.metric_logger = None  # TODO: fix this

     def setup(self):
-        self.train_dataloader = self.setup_data()
+        self.train_dataloader = self.setup_data(
+            self.job_config.dataset,
+            batch_size=self.job_config.training.local_batch_size,
+        )
+
+        self.val_dataloader = self.setup_data(
+            self.job_config.dataset_val,
+            batch_size=self.job_config.validation.local_batch_size,
+        )

         # self.train_dataloader = self.setup_data(
         #     self.train_config.train_dataset_config,
         #     self.train_config.train_dataloader_config,
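
For context, the new setup() code assumes the job config exposes dataset, dataset_val, and validation sections alongside the existing training section. A rough sketch of the assumed shape, inferred from the attribute accesses in this diff (these dataclasses are hypothetical, not the actual ForgeJobConfig definition):

```python
from dataclasses import dataclass


@dataclass
class DatasetConfig:
    # Mirrors dataset_config.path / dataset_config.split used in setup_data().
    path: str = "yahma/alpaca-cleaned"
    split: str = "train"


@dataclass
class ValidationConfig:
    # Mirrors the job_config.validation.* accesses in this diff.
    local_batch_size: int = 1
    freq: int = 0   # run validation every `freq` training steps; <= 0 disables it
    steps: int = 0  # max number of validation batches per run; <= 0 disables it


# Accessed in this diff as:
#   job_config.dataset, job_config.dataset_val -> DatasetConfig
#   job_config.validation                      -> ValidationConfig
#   job_config.training.local_batch_size       -> existing training section
```
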
@@ -79,7 +91,7 @@ def setup(self):
         # self.profiler = self.setup_profiler(self.train_config.profiler_config)
         # self.logger = self.setup_logger(self.train_config.logger_config)

-    def setup_data(self):
+    def setup_data(self, dataset_config, batch_size):
         tokenizer = HuggingFaceModelTokenizer(
             tokenizer_json_path=os.path.join(
                 self.job_config.model.hf_assets_path, "tokenizer.json"
@@ -95,8 +107,8 @@
         dataset = sft_iterable_dataset(
             model_transform=tokenizer,
             message_transform=AlpacaToMessages(),
-            path="yahma/alpaca-cleaned",
-            split="train",
+            path=dataset_config.path,
+            split=dataset_config.split,
         )
         packer = TextPacker(padding_idx=0)
         dataset = PackedDataset(
@@ -106,7 +118,7 @@
         )
         dataloader = StatefulDataLoader(
             dataset=dataset,
-            batch_size=self.job_config.training.local_batch_size,
+            batch_size=batch_size,
             collate_fn=partial(
                 collate_packed, mask_fn=packer.create_block_mask, device=self.device
             ),
@@ -119,7 +131,10 @@ def setup_data(self):
         return dataloader

     def forward_backward(
-        self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
+        self,
+        input_dict: dict[str, torch.Tensor],
+        labels: torch.Tensor,
+        do_backward: bool = True,
     ) -> torch.Tensor:
         model_parts = self.model_parts
         parallel_dims = self.parallel_dims
@@ -145,14 +160,16 @@ def forward_backward(
             targets, losses = (
                 (labels, []) if self.pp_has_last_stage else (None, None)
             )
+            if do_backward:
+                pp_schedule_fn = self.pp_schedule.step
+            else:
+                pp_schedule_fn = self.pp_schedule.eval
             if self.pp_has_first_stage:
-                self.pp_schedule.step(
+                pp_schedule_fn(
                     inputs, target=targets, losses=losses, input_batch=inputs
                 )
             else:
-                self.pp_schedule.step(
-                    target=targets, losses=losses, input_batch=inputs
-                )
+                pp_schedule_fn(target=targets, losses=losses, input_batch=inputs)

             # accumulate losses across pipeline microbatches
             # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
@@ -170,7 +187,8 @@ def forward_backward(
             loss = self.loss_fn(pred, labels)
             # need to free to before bwd to avoid peaking memory
             del pred
-            loss.backward()
+            if do_backward:
+                loss.backward()

         return loss
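
The do_backward flag keeps forward_backward() usable for both phases. A minimal usage sketch (the validation call mirrors the validate() method added in the next hunk; trainer stands in for the recipe instance and is not a name from this PR):

```python
# Training step (default): forward pass plus backward pass.
loss = trainer.forward_backward(input_dict, labels)

# Validation step: forward only. Wrapping in no_grad avoids building the
# autograd graph, and do_backward=False skips loss.backward().
with torch.no_grad():
    val_loss = trainer.forward_backward(input_dict, labels, do_backward=False)
```
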
@@ -214,6 +232,52 @@ def train(self) -> None:
                 last_step=self.current_step == self.num_training_steps,
             )

+            if (
+                self.job_config.validation.freq > 0
+                and self.job_config.validation.steps > 0
+                and self.current_step % self.job_config.validation.freq == 0
+            ):
+                self.validate(self.job_config.validation.steps)
+
+    def validate(self, max_steps: int) -> None:
+        for m in self.model_parts:
+            m.eval()
+        total_val_loss = torch.tensor(0.0, device=self.device)
+        total_val_tokens = torch.tensor(0.0, device=self.device)
+        with torch.no_grad():
+            val_pbar = tqdm(self.val_dataloader, desc="Validation", leave=False)
+            for batch_idx, batch in enumerate(val_pbar):
+                if batch_idx >= max_steps:
+                    break
+                batch_to_device(batch, self.device)
+                current_num_tokens = (batch["labels"] != CROSS_ENTROPY_IGNORE_IDX).sum()
+                # Compute loss
+                labels = batch.pop("labels")
+                loss = self.forward_backward(batch, labels, do_backward=False)
+                val_loss = loss * current_num_tokens
+                total_val_loss += val_loss
+                total_val_tokens += current_num_tokens
+                # Update progress bar description with current average loss
+                avg_loss_so_far = (
+                    (total_val_loss / total_val_tokens).item()
+                    if total_val_tokens > 0
+                    else float("inf")
+                )
+                val_pbar.set_description(
+                    f"Running validation Loss: {avg_loss_so_far:.4f}"
+                )
+        # Aggregate validation metrics across all ranks
+        torch.distributed.all_reduce(total_val_loss)
+        torch.distributed.all_reduce(total_val_tokens)
+        avg_val_loss = (
+            (total_val_loss / total_val_tokens).item()
+            if total_val_tokens > 0
+            else float("inf")
+        )
+        for m in self.model_parts:
+            m.train()
+        print(f"\nValidation loss: {avg_val_loss}")
+
     def cleanup(self) -> None:
         if self.checkpointer:
             self.checkpointer.close()
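
Note on the loss math in validate(): each batch contributes loss * current_num_tokens rather than the raw loss, and both the weighted sum and the token count are all-reduced across ranks before dividing, so batches and ranks with more non-ignored tokens weigh proportionally more in the final average. A minimal single-process sketch of the same reduction (names and the ignore index value are illustrative, not taken from the PR):

```python
import torch

IGNORE_IDX = -100  # assumed value of CROSS_ENTROPY_IGNORE_IDX


def token_weighted_average(batch_losses: list[torch.Tensor],
                           batch_labels: list[torch.Tensor]) -> float:
    """Average per-token loss over batches, skipping ignored label positions."""
    total_loss = torch.tensor(0.0)
    total_tokens = torch.tensor(0.0)
    for loss, labels in zip(batch_losses, batch_labels):
        n_tokens = (labels != IGNORE_IDX).sum()
        total_loss += loss * n_tokens  # loss is a per-token mean for this batch
        total_tokens += n_tokens
    # In the distributed run both running totals would be all-reduced across
    # ranks before this final division, which is what validate() does.
    return (total_loss / total_tokens).item() if total_tokens > 0 else float("inf")
```
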
Reviewer: One thing we should think about is how to support additional args beyond those we've already hardcoded. E.g. in #50 we also need to pass `data_files`. (This is more of a config system question, so it's OK to punt on it for now, but one path is to use something like `instantiate` for this; you can see this section in the torchtune docs for an example.)
Author: I can support passing file paths. Which one (`data_files` or `path`) should take priority? For example, what should happen if the user passes both `data_files` and `path`?