2 changes: 2 additions & 0 deletions apps/sft/llama3_8b.yaml
@@ -33,6 +33,8 @@ training:
steps: 1000
compile: false
dataset: "c4"
#eval_interval: 500 # Run evaluation every eval_interval training steps
#eval_steps: 100 # Number of validation batches per evaluation run

parallelism:
data_parallel_replicate_degree: 1
178 changes: 137 additions & 41 deletions apps/sft/main.py
@@ -7,7 +7,6 @@
"""To run:

python -m apps.sft.main --config apps/sft/llama3_8b.yaml

"""

import asyncio
@@ -40,8 +39,6 @@
from torchtitan.experiments.forge.engine import ForgeEngine
from torchtitan.experiments.forge.job_config import ForgeJobConfig

# from tqdm import tqdm

# stubs for now
Checkpointer = Any
Dataloader = Any
@@ -64,7 +61,7 @@ class ForgeSFTRecipe(ForgeActor, ForgeEngine):
checkpointer: Checkpointer
tokenizer: Tokenizer
train_dataloader: Dataloader
# val_dataloader: Dataloader
val_dataloader: Dataloader
metric_logger: MetricLogger
profiler: Profiler
device: torch.device
@@ -81,6 +78,11 @@ def __init__(self, config: DictConfig):
self.gradient_accumulation_steps = 1 # Example value, adjust as needed
self._rank = current_rank().rank
self._size = math.prod(current_size().values())

# Evaluation settings. With the float("inf") default, periodic evaluation
# never runs unless eval_interval is set in the training config.
self.eval_interval = job_config.training.get("eval_interval", float("inf"))
self.eval_steps = job_config.training.get("eval_steps", 0)

self._init_dist()
super().__init__(job_config)

@@ -111,25 +113,23 @@ def _init_dist(self):

@endpoint
async def setup(self):
self.train_dataloader = self.setup_data()
# self.train_dataloader = self.setup_data(
# self.train_config.train_dataset_config,
# self.train_config.train_dataloader_config,
# self.train_config.packing_config,
# )
# self.val_dataloader = self.setup_data(
# self.train_config.val_dataset_config,
# self.train_config.val_dataloader_config,
# self.train_config.packing_config,
# )

# TODO: confirm that this is working properly
# Should also use load, not dcp_load
# Setup training data (first 90% of train split)
self.train_dataloader = self.setup_data(
dataset_path="yahma/alpaca-cleaned", dataset_split="train[:90%]"
)

# Setup validation data (last 10% of train split)
self.val_dataloader = self.setup_data(
dataset_path="yahma/alpaca-cleaned", dataset_split="train[90%:]"
)

# Load checkpoint if resuming
self.checkpointer.load(step=self.current_step)
# self.profiler = self.setup_profiler(self.train_config.profiler_config)
# self.logger = self.setup_logger(self.train_config.logger_config)

def setup_data(self):
def setup_data(
self, dataset_path: str = "yahma/alpaca-cleaned", dataset_split: str = "train"
):
"""Setup data with configurable dataset path and split."""
print(os.path.join(self.job_config.model.hf_assets_path, "tokenizer.json"))
tokenizer = HuggingFaceModelTokenizer(
tokenizer_json_path=os.path.join(
@@ -146,8 +146,8 @@ def setup_data(self):
dataset = sft_iterable_dataset(
model_transform=tokenizer,
message_transform=AlpacaToMessages(),
Contributor: Ideally we shouldn't hardcode this either (but it's a bit more work without instantiate).

Author: I agree. We can have this implemented as soon as we fix the main eval functionality.
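A minimal sketch of what that could look like, assuming hypothetical training-config keys dataset_path, dataset_train_split, and dataset_val_split (not part of this diff):

# Hypothetical config keys; they would need to be added to llama3_8b.yaml.
training_cfg = self.job_config.training
dataset_path = training_cfg.get("dataset_path", "yahma/alpaca-cleaned")
self.train_dataloader = self.setup_data(
    dataset_path=dataset_path,
    dataset_split=training_cfg.get("dataset_train_split", "train[:90%]"),
)
self.val_dataloader = self.setup_data(
    dataset_path=dataset_path,
    dataset_split=training_cfg.get("dataset_val_split", "train[90%:]"),
)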

path="yahma/alpaca-cleaned",
split="train",
path=dataset_path,
split=dataset_split,
)
packer = TextPacker(padding_idx=0)
dataset = PackedDataset(
@@ -163,10 +163,6 @@
),
)

# Ultimately we probably want something like this
# packer = build_packing_strategy(packing_config)
# dataset = build_dataset(dataset_config)
# dataloader = build_dataloader(dataloader_config, dataset, packer)
return dataloader

def forward_backward(
@@ -206,7 +202,6 @@ def forward_backward(
)

# accumulate losses across pipeline microbatches
# TODO: PP+FSDP unexpectedly puts the loss back to the CPU
loss = (
torch.mean(torch.stack(losses)).to(self.device)
if self.pp_has_last_stage
@@ -225,27 +220,125 @@

return loss

def forward_only(
self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
) -> torch.Tensor:
"""Forward pass only for evaluation (no backward)."""
model_parts = self.model_parts
parallel_dims = self.parallel_dims

inputs = input_dict["tokens"]
optional_context_parallel_ctx = (
dist_utils.create_context_parallel_ctx(
cp_mesh=parallel_dims.world_mesh["cp"],
cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],
cp_seq_dims=[1, 1] + [0 for _ in model_parts],
cp_no_restore_buffers={inputs, labels},
cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
)
if parallel_dims.cp_enabled
else None
)

if parallel_dims.pp_enabled:
# Pipeline Parallel forward only
with self.train_context(optional_context_parallel_ctx):
targets, losses = (
(labels, []) if self.pp_has_last_stage else (None, None)
)
if self.pp_has_first_stage:
self.pp_schedule.step(
inputs, target=targets, losses=losses, input_batch=inputs
)
else:
self.pp_schedule.step(
target=targets, losses=losses, input_batch=inputs
)

# accumulate losses across pipeline microbatches
# (non-last stages report a placeholder loss of -1.0)
loss = (
torch.mean(torch.stack(losses)).to(self.device)
if self.pp_has_last_stage
else torch.tensor([-1.0], device=self.device)
)
else:
# Non-PP forward only
with self.train_context(optional_context_parallel_ctx):
assert len(model_parts) == 1
with self.maybe_enable_amp:
pred = model_parts[0](inputs)
loss = self.loss_fn(pred, labels)
del pred

return loss

def train_step(self, batch) -> None:
# TODO
# with GradientAccumulation(
# self.gradient_accumulation_steps,
# self.model,
# self.data_parallel_size,
# ) as grad_acc:
labels = batch.pop("labels")
loss = self.forward_backward(batch, labels)

logger.info(f"{self.current_step} / {self.num_training_steps}|Loss: {loss}")
# self.pbar.set_description(f"{self.current_step}|Loss: {loss}")
# self.pbar.update(1)
self.optimizers.step()
self.lr_schedulers.step()

async def evaluate(self) -> dict[str, float]:
"""Run evaluation on validation set (internal method, not an endpoint)."""
logger.info("=" * 50)
logger.info("STARTING EVALUATION ")
logger.info("=" * 50)

# Set model to eval mode
for model_part in self.model_parts:
model_part.eval()

val_dataloader = iter(self.val_dataloader)
total_loss = 0.0
num_batches = 0

with torch.no_grad():
for step in range(self.eval_steps):
try:
batch = next(val_dataloader)

# Move tensors to device
for k, v in batch.items():
if isinstance(v, torch.Tensor):
batch[k] = v.to(self.device)

labels = batch.pop("labels")
loss = self.forward_only(batch, labels)

total_loss += loss.item()
num_batches += 1

logger.info(
f" Eval batch {num_batches}/{self.eval_steps} | Loss: {loss.item():.4f}"
)

except StopIteration:
logger.warning("Reached end of validation dataloader early")
break

# Set model back to train mode
for model_part in self.model_parts:
model_part.train()

avg_loss = total_loss / max(num_batches, 1) # avoid division by zero if no batches ran

metrics = {
"val_loss": avg_loss,
"val_batches": num_batches,
}

logger.info("-" * 50)
logger.info(f"EVALUATION COMPLETE")
logger.info(f"Validation Loss: {avg_loss:.4f}")
logger.info(f"Batches Evaluated: {num_batches}")
logger.info("=" * 50)
return metrics

@endpoint
async def train(self) -> None:
dataloader = iter(self.train_dataloader)
self.optimizers.zero_grad()

# TODO: tqdm is broken in Monarch actors
# self.pbar = tqdm(initial=self.current_step, total=self.num_training_steps)

@@ -254,18 +347,21 @@ async def train(self) -> None:
# Move tensors to the appropriate device
for k, v in batch.items():
if isinstance(v, torch.Tensor):
batch[k] = v.to("cuda") # TODO: hardcoded for now
batch[k] = v.to(self.device)
self.train_step(batch)
# self.profiler.step()
self.current_step += 1

# Run evaluation periodically
if self.current_step % self.eval_interval == 0:
eval_metrics = await self.evaluate()
logger.info(f"Step {self.current_step} | Eval metrics: {eval_metrics}")

# Save checkpoints
self.checkpointer.save(
curr_step=self.current_step,
last_step=self.current_step == self.num_training_steps,
)

# self.pbar.close()

@endpoint
async def cleanup(self) -> None:
if self.checkpointer: