16 changes: 15 additions & 1 deletion apps/sft/llama3_8b.yaml
@@ -26,13 +26,27 @@ optimizer:
lr_scheduler:
warmup_steps: 200

dataset:
path: "yahma/alpaca-cleaned"
split: "train[:95%]"

dataset_val:
path: "yahma/alpaca-cleaned"
split: "train[95%:]"

training:
local_batch_size: 1
seq_len: 2048
max_norm: 1.0
steps: 1000
compile: false
dataset: "c4"


validation:
enabled: true # Enable/disable validation
eval_interval: 100 # Run evaluation every 100 training steps
eval_steps: 50 # Number of batches per evaluation (0 = full epoch)


parallelism:
data_parallel_replicate_degree: 1
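For reference, a minimal sketch of how the three `validation` keys above are intended to drive a run (the helper names are hypothetical, not part of this PR; the real wiring is in apps/sft/main.py below). Evaluation fires every `eval_interval` training steps, and each evaluation processes at most `eval_steps` batches, with `0` meaning run until the validation epoch ends.

    # Hypothetical helpers mirroring the YAML keys above -- not part of this PR.

    def should_evaluate(step: int, enabled: bool, eval_interval: int) -> bool:
        """Evaluation runs every `eval_interval` training steps when validation is enabled."""
        return enabled and step % eval_interval == 0

    def should_stop_eval(batches_done: int, eval_steps: int, epoch_finished: bool) -> bool:
        """Stop at the epoch boundary, or after `eval_steps` batches when eval_steps > 0."""
        return epoch_finished or (eval_steps > 0 and batches_done >= eval_steps)
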
249 changes: 208 additions & 41 deletions apps/sft/main.py
@@ -7,7 +7,6 @@
"""To run:

python -m apps.sft.main --config apps/sft/llama3_8b.yaml

"""

import asyncio
@@ -40,8 +39,6 @@
from torchtitan.experiments.forge.engine import ForgeEngine
from torchtitan.experiments.forge.job_config import ForgeJobConfig

# from tqdm import tqdm

# stubs for now
Checkpointer = Any
Dataloader = Any
@@ -64,7 +61,7 @@ class ForgeSFTRecipe(ForgeActor, ForgeEngine):
checkpointer: Checkpointer
tokenizer: Tokenizer
train_dataloader: Dataloader
# val_dataloader: Dataloader
val_dataloader: Dataloader
metric_logger: MetricLogger
profiler: Profiler
device: torch.device
@@ -81,6 +78,27 @@ def __init__(self, config: DictConfig):
self.gradient_accumulation_steps = 1 # Example value, adjust as needed
self._rank = current_rank().rank
self._size = math.prod(current_size().values())

# Evaluation settings from validation config
validation_config = job_config.get("validation", {})
self.validation_enabled = validation_config.get("enabled", False)

if self.validation_enabled:
self.eval_interval = validation_config.get("eval_interval")
self.eval_steps = validation_config.get("eval_steps")

if self.eval_interval is None:
raise ValueError(
"validation.eval_interval is required when validation.enabled is true"
)
if self.eval_steps is None:
raise ValueError(
"validation.eval_steps is required when validation.enabled is true"
)
else:
self.eval_interval = None
self.eval_steps = None

self._init_dist()
super().__init__(job_config)

@@ -111,25 +129,30 @@ def _init_dist(self):

@endpoint
async def setup(self):
self.train_dataloader = self.setup_data()
# self.train_dataloader = self.setup_data(
# self.train_config.train_dataset_config,
# self.train_config.train_dataloader_config,
# self.train_config.packing_config,
# )
# self.val_dataloader = self.setup_data(
# self.train_config.val_dataset_config,
# self.train_config.val_dataloader_config,
# self.train_config.packing_config,
# )

# TODO: confirm that this is working properly
# Should also use load, not dcp_load
# Setup training data from config
dataset_config = self.job_config.get("dataset")

self.train_dataloader = self.setup_data(
dataset_path=dataset_config.get("path"),
dataset_split=dataset_config.get("split"),
)

# Setup validation data from config
dataset_val_config = self.job_config.get("dataset_val", {})
self.val_dataloader = self.setup_data(
dataset_path=dataset_val_config.get("path", dataset_config.get("path")),
dataset_split=dataset_val_config.get("split", dataset_config.get("split")),
)

# Load checkpoint if resuming
self.checkpointer.load(step=self.current_step)
# self.profiler = self.setup_profiler(self.train_config.profiler_config)
# self.logger = self.setup_logger(self.train_config.logger_config)

def setup_data(self):
def setup_data(self, dataset_path: str, dataset_split: str):
"""Setup data with configurable dataset path and split."""
if not dataset_path or not dataset_split:
raise ValueError(
f"dataset.path and dataset.split are required in YAML config. Got path={dataset_path}, split={dataset_split}"
)
print(os.path.join(self.job_config.model.hf_assets_path, "tokenizer.json"))
tokenizer = HuggingFaceModelTokenizer(
tokenizer_json_path=os.path.join(
@@ -146,8 +169,8 @@ def setup_data(self):
dataset = sft_iterable_dataset(
model_transform=tokenizer,
message_transform=AlpacaToMessages(),
Contributor: Ideally we shouldn't hardcode this either (but it's a bit more work without instantiate)

Author: I agree. We can have this implemented as soon as we fix the main eval functionality

path="yahma/alpaca-cleaned",
split="train",
path=dataset_path,
split=dataset_split,
)
packer = TextPacker(padding_idx=0)
dataset = PackedDataset(
@@ -163,10 +186,6 @@
),
)

# Ultimately we probably want something like this
# packer = build_packing_strategy(packing_config)
# dataset = build_dataset(dataset_config)
# dataloader = build_dataloader(dataloader_config, dataset, packer)
return dataloader
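
A hypothetical sketch of the config-driven message transform discussed in the review thread above. The registry, the `dataset.message_transform` key, and `build_message_transform` are illustrative names only, not part of this PR:

    def build_message_transform(name: str, registry: dict[str, type]):
        """Resolve a message-transform class by name and instantiate it."""
        try:
            return registry[name]()
        except KeyError:
            raise ValueError(f"Unknown dataset.message_transform: {name!r}") from None

    # Rough usage inside setup_data(), assuming AlpacaToMessages is already imported there:
    #   transform = build_message_transform(
    #       dataset_config.get("message_transform", "alpaca"),
    #       registry={"alpaca": AlpacaToMessages},
    #   )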

def forward_backward(
@@ -206,7 +225,6 @@ def forward_backward(
)

# accumulate losses across pipeline microbatches
# TODO: PP+FSDP unexpectedly puts the loss back to the CPU
loss = (
torch.mean(torch.stack(losses)).to(self.device)
if self.pp_has_last_stage
@@ -225,27 +243,173 @@

return loss

def forward_only(
self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
) -> torch.Tensor:
"""Forward pass only for evaluation (no backward)."""
model_parts = self.model_parts
parallel_dims = self.parallel_dims

inputs = input_dict["tokens"]
optional_context_parallel_ctx = (
dist_utils.create_context_parallel_ctx(
cp_mesh=parallel_dims.world_mesh["cp"],
cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],
cp_seq_dims=[1, 1] + [0 for _ in model_parts],
cp_no_restore_buffers={inputs, labels},
cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
)
if parallel_dims.cp_enabled
else None
)

if parallel_dims.pp_enabled:
# Pipeline Parallel forward only
with self.train_context(optional_context_parallel_ctx):
targets, losses = (
(labels, []) if self.pp_has_last_stage else (None, None)
)
if self.pp_has_first_stage:
self.pp_schedule.step(
inputs, target=targets, losses=losses, input_batch=inputs
)
else:
self.pp_schedule.step(
target=targets, losses=losses, input_batch=inputs
)

loss = (
torch.mean(torch.stack(losses)).to(self.device)
if self.pp_has_last_stage
else torch.tensor([-1.0], device=self.device)
)
else:
# Non-PP forward only
with self.train_context(optional_context_parallel_ctx):
assert len(model_parts) == 1
with self.maybe_enable_amp:
pred = model_parts[0](inputs)
loss = self.loss_fn(pred, labels)
del pred

return loss

def train_step(self, batch) -> None:
# TODO
# with GradientAccumulation(
# self.gradient_accumulation_steps,
# self.model,
# self.data_parallel_size,
# ) as grad_acc:
labels = batch.pop("labels")
loss = self.forward_backward(batch, labels)

logger.info(f"{self.current_step} / {self.num_training_steps}|Loss: {loss}")
# self.pbar.set_description(f"{self.current_step}|Loss: {loss}")
# self.pbar.update(1)
self.optimizers.step()
self.lr_schedulers.step()

def _extract_epoch_from_batch(self, batch: dict) -> int | None:
"""Extract epoch number from batch metrics."""
if "metrics" in batch:
for metric in batch["metrics"]:
if (
hasattr(metric, "metric_name")
and metric.metric_name == "num_epochs"
):
return metric.value
return None

async def evaluate(self) -> dict[str, float]:
"""Run evaluation with async all_reduce for cross-rank epoch synchronization."""
logger.info("=" * 50)
logger.info("STARTING EVALUATION")
logger.info("=" * 50)

for model_part in self.model_parts:
model_part.eval()

val_dataloader = iter(self.val_dataloader)
total_loss, num_batches, starting_epoch = 0.0, 0, None

# Prefetch first batch
try:
next_batch = next(val_dataloader)
except StopIteration:
logger.warning("Validation dataloader is empty")
return {"val_loss": 0.0, "val_batches": 0}

should_break, pending_work, epoch_tensor = False, None, None

with torch.no_grad():
while True:
# Wait for previous async all_reduce to complete
if pending_work is not None:
pending_work.wait()
should_break = (
epoch_tensor.item() > 0 if epoch_tensor is not None else False
)
pending_work = None

if should_break:
logger.info(
"Epoch completed across all ranks - stopping evaluation"
)
break

if self.eval_steps > 0 and num_batches >= self.eval_steps:
logger.info(f"Reached eval_steps cap of {self.eval_steps}")
break

batch = next_batch

# Track starting epoch
current_epoch = self._extract_epoch_from_batch(batch)
if current_epoch is not None and starting_epoch is None:
starting_epoch = current_epoch

# Prefetch next batch and start async epoch check
try:
next_batch = next(val_dataloader)
next_epoch = self._extract_epoch_from_batch(next_batch)

if next_epoch is not None and starting_epoch is not None:
epoch_increment = next_epoch - starting_epoch
if torch.distributed.is_initialized():
epoch_tensor = torch.tensor(
[epoch_increment], dtype=torch.long, device=self.device
)
pending_work = torch.distributed.all_reduce(
epoch_tensor,
op=torch.distributed.ReduceOp.MAX,
async_op=True,
)
else:
should_break = epoch_increment > 0
except StopIteration:
should_break = True

# Process current batch (overlaps with async all_reduce)
for k, v in batch.items():
if isinstance(v, torch.Tensor):
batch[k] = v.to(self.device)

labels = batch.pop("labels")
loss = self.forward_only(batch, labels)
total_loss += loss.item()
num_batches += 1

if num_batches % 10 == 0:
logger.info(f" Eval batch {num_batches} | Loss: {loss.item():.4f}")

for model_part in self.model_parts:
model_part.train()

avg_loss = total_loss / max(num_batches, 1)
logger.info(
f"EVALUATION COMPLETE | Val Loss: {avg_loss:.4f} | Batches: {num_batches}"
)
logger.info("=" * 50)

return {"val_loss": avg_loss, "val_batches": num_batches}

@endpoint
async def train(self) -> None:
dataloader = iter(self.train_dataloader)
self.optimizers.zero_grad()

# TODO: tqdm is broken in Monarch actors
# self.pbar = tqdm(initial=self.current_step, total=self.num_training_steps)

@@ -254,18 +418,21 @@ async def train(self) -> None:
# Move tensors to the appropriate device
for k, v in batch.items():
if isinstance(v, torch.Tensor):
batch[k] = v.to("cuda") # TODO: hardcoded for now
batch[k] = v.to(self.device)
self.train_step(batch)
# self.profiler.step()
self.current_step += 1

# Run evaluation periodically if enabled
if self.validation_enabled and self.current_step % self.eval_interval == 0:
eval_metrics = await self.evaluate()
logger.info(f"Step {self.current_step} | Eval metrics: {eval_metrics}")

# Save checkpoints
self.checkpointer.save(
curr_step=self.current_step,
last_step=self.current_step == self.num_training_steps,
)

# self.pbar.close()

@endpoint
async def cleanup(self) -> None:
if self.checkpointer:
15 changes: 14 additions & 1 deletion apps/sft/qwen3_8b.yaml
@@ -25,13 +25,26 @@ optimizer:
lr_scheduler:
warmup_steps: 200

# Dataset configuration
dataset:
path: "yahma/alpaca-cleaned"
split: "train[:95%]"

dataset_val:
path: "yahma/alpaca-cleaned"
split: "train[95%:]"

training:
local_batch_size: 1
seq_len: 2048
max_norm: 1.0
steps: 1000
compile: false
dataset: "c4"

validation:
enabled: true # Enable/disable validation
eval_interval: 100 # Run evaluation every 100 training steps
eval_steps: 50 # Number of batches per evaluation (0 = full epoch)

parallelism:
data_parallel_replicate_degree: 1