apps/sft/llama3_8b.yaml (13 additions, 1 deletion)
@@ -31,7 +31,19 @@ training:
max_norm: 1.0
steps: 1000
compile: false
dataset: "c4"

validation:
local_batch_size: 1
freq: -1 # Change to a positive number to enable validation
steps: 200 # Max steps to run validation. Validation disabled if negative.

dataset:
path: yahma/alpaca-cleaned
split: train[:95%]

dataset_val:
path: yahma/alpaca-cleaned
split: train[95%:]

parallelism:
data_parallel_replicate_degree: 1
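The new `dataset` / `dataset_val` entries carve a held-out 5% slice out of the single `train` split of `yahma/alpaca-cleaned` using Hugging Face split-slicing syntax. The recipe consumes these fields through `sft_iterable_dataset` in `apps/sft/main.py`; the direct `datasets.load_dataset` call below is only a sketch to illustrate what the two split strings select.

```python
from datasets import load_dataset

# "train[:95%]" / "train[95%:]" use the datasets library's slicing syntax:
# the first 95% of examples for training, the remaining 5% for validation.
train_ds = load_dataset("yahma/alpaca-cleaned", split="train[:95%]")
val_ds = load_dataset("yahma/alpaca-cleaned", split="train[95%:]")

# The two slices partition the original split with no overlap.
print(len(train_ds), len(val_ds))
```

Validation itself stays disabled until `validation.freq` is set to a positive value, as the comment in the YAML notes.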
apps/sft/main.py (75 additions, 11 deletions)
@@ -18,9 +18,11 @@
from forge.data.datasets.packed import PackedDataset, TextPacker
from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset
from forge.data.tokenizer import HuggingFaceModelTokenizer
from forge.data.utils import batch_to_device, CROSS_ENTROPY_IGNORE_IDX

from omegaconf import DictConfig, OmegaConf
from torch import nn

from torchdata.stateful_dataloader import StatefulDataLoader
from torchtitan.components.loss import LossFunction
from torchtitan.components.lr_scheduler import LRSchedulersContainer
@@ -30,6 +32,7 @@
from torchtitan.experiments.forge.job_config import ForgeJobConfig
from tqdm import tqdm


# stubs for now
Checkpointer = Any
Dataloader = Any
@@ -63,7 +66,16 @@ def __init__(self, job_config: ForgeJobConfig):
self.metric_logger = None # TODO: fix this

def setup(self):
-self.train_dataloader = self.setup_data()
self.train_dataloader = self.setup_data(
self.job_config.dataset,
batch_size=self.job_config.training.local_batch_size,
)

self.val_dataloader = self.setup_data(
self.job_config.dataset_val,
batch_size=self.job_config.validation.local_batch_size,
)

# self.train_dataloader = self.setup_data(
# self.train_config.train_dataset_config,
# self.train_config.train_dataloader_config,
@@ -79,7 +91,7 @@ def setup(self):
# self.profiler = self.setup_profiler(self.train_config.profiler_config)
# self.logger = self.setup_logger(self.train_config.logger_config)

-def setup_data(self):
def setup_data(self, dataset_config, batch_size):
tokenizer = HuggingFaceModelTokenizer(
tokenizer_json_path=os.path.join(
self.job_config.model.hf_assets_path, "tokenizer.json"
@@ -95,8 +107,8 @@ def setup_data(self):
dataset = sft_iterable_dataset(
model_transform=tokenizer,
message_transform=AlpacaToMessages(),
path="yahma/alpaca-cleaned",
split="train",
path=dataset_config.path,
split=dataset_config.split,
)
packer = TextPacker(padding_idx=0)
dataset = PackedDataset(
@@ -106,7 +118,7 @@ def setup_data(self):
)
dataloader = StatefulDataLoader(
dataset=dataset,
-batch_size=self.job_config.training.local_batch_size,
batch_size=batch_size,
collate_fn=partial(
collate_packed, mask_fn=packer.create_block_mask, device=self.device
),
@@ -119,7 +131,10 @@ def setup_data(self):
return dataloader

def forward_backward(
-self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
self,
input_dict: dict[str, torch.Tensor],
labels: torch.Tensor,
do_backward: bool = True,
) -> torch.Tensor:
model_parts = self.model_parts
parallel_dims = self.parallel_dims
@@ -145,14 +160,16 @@ def forward_backward(
targets, losses = (
(labels, []) if self.pp_has_last_stage else (None, None)
)
if do_backward:
pp_schedule_fn = self.pp_schedule.step
else:
pp_schedule_fn = self.pp_schedule.eval
if self.pp_has_first_stage:
-self.pp_schedule.step(
pp_schedule_fn(
inputs, target=targets, losses=losses, input_batch=inputs
)
else:
-self.pp_schedule.step(
-target=targets, losses=losses, input_batch=inputs
-)
pp_schedule_fn(target=targets, losses=losses, input_batch=inputs)

# accumulate losses across pipeline microbatches
# TODO: PP+FSDP unexpectedly puts the loss back to the CPU
@@ -170,7 +187,8 @@ def forward_backward(
loss = self.loss_fn(pred, labels)
# need to free it before bwd to avoid peaking memory
del pred
-loss.backward()
if do_backward:
loss.backward()

return loss
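The new `do_backward` flag lets `validate()` reuse `forward_backward` for a forward-only pass, and because the caller also wraps the loop in `torch.no_grad()`, no autograd graph is built during validation. Below is a standalone sketch of the same gating pattern; the toy model, loss, and function name are illustrative stand-ins, not the recipe's code.

```python
import torch
from torch import nn


def forward_maybe_backward(
    model: nn.Module, x: torch.Tensor, y: torch.Tensor, do_backward: bool = True
) -> torch.Tensor:
    # One code path for training and evaluation; backward is gated behind a flag.
    loss = nn.functional.mse_loss(model(x), y)
    if do_backward:
        loss.backward()
    return loss


model = nn.Linear(4, 1)
x, y = torch.randn(8, 4), torch.randn(8, 1)

train_loss = forward_maybe_backward(model, x, y)  # gradients are computed
with torch.no_grad():  # eval path: no graph is recorded
    val_loss = forward_maybe_backward(model, x, y, do_backward=False)
print(train_loss.item(), val_loss.item())
```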

@@ -214,6 +232,52 @@ def train(self) -> None:
last_step=self.current_step == self.num_training_steps,
)

if (
self.job_config.validation.freq > 0
and self.job_config.validation.steps > 0
and self.current_step % self.job_config.validation.freq == 0
):
self.validate(self.job_config.validation.steps)

def validate(self, max_steps: int) -> None:
for m in self.model_parts:
m.eval()
total_val_loss = torch.tensor(0.0, device=self.device)
total_val_tokens = torch.tensor(0.0, device=self.device)
with torch.no_grad():
val_pbar = tqdm(self.val_dataloader, desc="Validation", leave=False)
for batch_idx, batch in enumerate(val_pbar):
if batch_idx >= max_steps:
break
batch_to_device(batch, self.device)
current_num_tokens = (batch["labels"] != CROSS_ENTROPY_IGNORE_IDX).sum()
# Compute loss
labels = batch.pop("labels")
loss = self.forward_backward(batch, labels, do_backward=False)
val_loss = loss * current_num_tokens
total_val_loss += val_loss
total_val_tokens += current_num_tokens
# Update progress bar description with current average loss
avg_loss_so_far = (
(total_val_loss / total_val_tokens).item()
if total_val_tokens > 0
else float("inf")
)
val_pbar.set_description(
f"Running validation Loss: {avg_loss_so_far:.4f}"
)
# Aggregate validation metrics across all ranks
torch.distributed.all_reduce(total_val_loss)
torch.distributed.all_reduce(total_val_tokens)
avg_val_loss = (
(total_val_loss / total_val_tokens).item()
if total_val_tokens > 0
else float("inf")
)
for m in self.model_parts:
m.train()
print(f"\nValidation loss: {avg_val_loss}")

def cleanup(self) -> None:
if self.checkpointer:
self.checkpointer.close()
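`validate()` multiplies each batch's loss by its count of non-ignored label tokens, accumulates those products, and only divides (and all-reduces) at the end, so the reported number is a per-token average over every batch and every rank rather than a mean of per-batch means. Assuming the loss function returns a per-token mean (which is what the multiplication by `current_num_tokens` suggests), here is a small single-process sketch of that aggregation with made-up numbers:

```python
import torch

# Hypothetical (per-token mean loss, non-ignored token count) pairs for three batches.
batches = [(2.10, 512), (1.95, 480), (2.40, 128)]

total_loss = torch.tensor(0.0)
total_tokens = torch.tensor(0.0)
for mean_loss, num_tokens in batches:
    total_loss += mean_loss * num_tokens  # un-average, then accumulate
    total_tokens += num_tokens

per_token_avg = (total_loss / total_tokens).item()  # ~2.07
naive_mean = sum(loss for loss, _ in batches) / len(batches)  # ~2.15, over-weights the small batch
print(per_token_avg, naive_mean)
```

In the distributed case the two running totals are what get all-reduced, so every rank ends up with the same global per-token average.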