16 changes: 15 additions & 1 deletion apps/sft/llama3_8b.yaml
@@ -26,13 +26,27 @@ optimizer:
lr_scheduler:
warmup_steps: 200

dataset:
path: "yahma/alpaca-cleaned"
split: "train[:95%]"

dataset_val:
path: "yahma/alpaca-cleaned"
split: "train[95%:]"

training:
local_batch_size: 1
seq_len: 2048
max_norm: 1.0
steps: 1000
compile: false
dataset: "c4"


validation:
enabled: true # Enable/disable validation
eval_interval: 100 # Run evaluation every 100 training steps
eval_steps: 50 # Number of batches per evaluation (0 = full epoch)


parallelism:
data_parallel_replicate_degree: 1
249 changes: 208 additions & 41 deletions apps/sft/main.py
@@ -7,7 +7,6 @@
"""To run:

python -m apps.sft.main --config apps/sft/llama3_8b.yaml

"""

import asyncio
@@ -40,8 +39,6 @@
from torchtitan.experiments.forge.engine import ForgeEngine
from torchtitan.experiments.forge.job_config import ForgeJobConfig

# from tqdm import tqdm

# stubs for now
Checkpointer = Any
Dataloader = Any
@@ -64,7 +61,7 @@ class ForgeSFTRecipe(ForgeActor, ForgeEngine):
checkpointer: Checkpointer
tokenizer: Tokenizer
train_dataloader: Dataloader
# val_dataloader: Dataloader
val_dataloader: Dataloader
metric_logger: MetricLogger
profiler: Profiler
device: torch.device
@@ -81,6 +78,27 @@ def __init__(self, config: DictConfig):
self.gradient_accumulation_steps = 1 # Example value, adjust as needed
self._rank = current_rank().rank
self._size = math.prod(current_size().values())

# Evaluation settings from validation config
validation_config = job_config.get("validation", {})
self.validation_enabled = validation_config.get("enabled", False)

if self.validation_enabled:
self.eval_interval = validation_config.get("eval_interval")
self.eval_steps = validation_config.get("eval_steps")

if self.eval_interval is None:
raise ValueError(
"validation.eval_interval is required when validation.enabled is true"
)
if self.eval_steps is None:
raise ValueError(
"validation.eval_steps is required when validation.enabled is true"
)
else:
self.eval_interval = None
self.eval_steps = None

self._init_dist()
super().__init__(job_config)

@@ -111,25 +129,30 @@ def _init_dist(self):

@endpoint
async def setup(self):
self.train_dataloader = self.setup_data()
# self.train_dataloader = self.setup_data(
# self.train_config.train_dataset_config,
# self.train_config.train_dataloader_config,
# self.train_config.packing_config,
# )
# self.val_dataloader = self.setup_data(
# self.train_config.val_dataset_config,
# self.train_config.val_dataloader_config,
# self.train_config.packing_config,
# )

# TODO: confirm that this is working properly
# Should also use load, not dcp_load
# Setup training data from config
dataset_config = self.job_config.get("dataset")

self.train_dataloader = self.setup_data(
dataset_path=dataset_config.get("path"),
dataset_split=dataset_config.get("split"),
)

# Setup validation data from config
dataset_val_config = self.job_config.get("dataset_val", {})
self.val_dataloader = self.setup_data(
dataset_path=dataset_val_config.get("path", dataset_config.get("path")),
dataset_split=dataset_val_config.get("split", dataset_config.get("split")),
Review comment (Contributor): Nit, but I don't like these nested .get calls. It also seems strange that we would fall back to validating on the training set. Personally I would just recommend checking whether validation is enabled, and if it's not, not setting up the validation dataloader at all.
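A minimal sketch of that suggestion, reusing the surrounding setup() code (the None default and the error message are choices made here, not part of the PR):

    # Sketch: build the validation dataloader only when validation is enabled,
    # so there is no silent fallback onto the training split.
    self.val_dataloader = None
    if self.validation_enabled:
        dataset_val_config = self.job_config.get("dataset_val")
        if dataset_val_config is None:
            raise ValueError("dataset_val is required when validation.enabled is true")
        self.val_dataloader = self.setup_data(
            dataset_path=dataset_val_config.get("path"),
            dataset_split=dataset_val_config.get("split"),
        )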

)

# Load checkpoint if resuming
self.checkpointer.load(step=self.current_step)
# self.profiler = self.setup_profiler(self.train_config.profiler_config)
# self.logger = self.setup_logger(self.train_config.logger_config)

def setup_data(self):
def setup_data(self, dataset_path: str, dataset_split: str):
"""Setup data with configurable dataset path and split."""
if not dataset_path or not dataset_split:
raise ValueError(
f"dataset.path and dataset.split are required in YAML config. Got path={dataset_path}, split={dataset_split}"
)
print(os.path.join(self.job_config.model.hf_assets_path, "tokenizer.json"))
tokenizer = HuggingFaceModelTokenizer(
tokenizer_json_path=os.path.join(
@@ -146,8 +169,8 @@ def setup_data(self):
dataset = sft_iterable_dataset(
model_transform=tokenizer,
message_transform=AlpacaToMessages(),
Review comment (Contributor): Ideally we shouldn't hardcode this either (but it's a bit more work without instantiate).
Reply (Author): I agree. We can have this implemented as soon as we fix the main eval functionality.

path="yahma/alpaca-cleaned",
split="train",
path=dataset_path,
split=dataset_split,
)
packer = TextPacker(padding_idx=0)
dataset = PackedDataset(
@@ -163,10 +186,6 @@
),
)

# Ultimately we probably want something like this
# packer = build_packing_strategy(packing_config)
# dataset = build_dataset(dataset_config)
# dataloader = build_dataloader(dataloader_config, dataset, packer)
return dataloader

def forward_backward(
@@ -206,7 +225,6 @@ def forward_backward(
)

# accumulate losses across pipeline microbatches
# TODO: PP+FSDP unexpectedly puts the loss back to the CPU
loss = (
torch.mean(torch.stack(losses)).to(self.device)
if self.pp_has_last_stage
@@ -225,27 +243,173 @@

return loss

def forward_only(
self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
) -> torch.Tensor:
"""Forward pass only for evaluation (no backward)."""
model_parts = self.model_parts
parallel_dims = self.parallel_dims

inputs = input_dict["tokens"]
optional_context_parallel_ctx = (
dist_utils.create_context_parallel_ctx(
cp_mesh=parallel_dims.world_mesh["cp"],
cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],
cp_seq_dims=[1, 1] + [0 for _ in model_parts],
cp_no_restore_buffers={inputs, labels},
cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
)
if parallel_dims.cp_enabled
else None
)

if parallel_dims.pp_enabled:
# Pipeline Parallel forward only
with self.train_context(optional_context_parallel_ctx):
targets, losses = (
(labels, []) if self.pp_has_last_stage else (None, None)
)
if self.pp_has_first_stage:
self.pp_schedule.step(
inputs, target=targets, losses=losses, input_batch=inputs
)
else:
self.pp_schedule.step(
target=targets, losses=losses, input_batch=inputs
)

loss = (
torch.mean(torch.stack(losses)).to(self.device)
if self.pp_has_last_stage
else torch.tensor([-1.0], device=self.device)
)
else:
# Non-PP forward only
with self.train_context(optional_context_parallel_ctx):
assert len(model_parts) == 1
with self.maybe_enable_amp:
pred = model_parts[0](inputs)
loss = self.loss_fn(pred, labels)
del pred

return loss

def train_step(self, batch) -> None:
# TODO
# with GradientAccumulation(
# self.gradient_accumulation_steps,
# self.model,
# self.data_parallel_size,
# ) as grad_acc:
labels = batch.pop("labels")
loss = self.forward_backward(batch, labels)

logger.info(f"{self.current_step} / {self.num_training_steps}|Loss: {loss}")
# self.pbar.set_description(f"{self.current_step}|Loss: {loss}")
# self.pbar.update(1)
self.optimizers.step()
self.lr_schedulers.step()

def _extract_epoch_from_batch(self, batch: dict) -> int | None:
"""Extract epoch number from batch metrics."""
if "metrics" in batch:
for metric in batch["metrics"]:
if (
hasattr(metric, "metric_name")
and metric.metric_name == "num_epochs"
):
return metric.value
return None

async def evaluate(self) -> dict[str, float]:
Review comment (Contributor): In addition to @felipemello1's more detailed comments, one higher-level point: this eval implementation adds a lot of code to main.py, which, as the entry point, is something everyone will have to read. (This PR alone increases the total LoC by more than 50%, and the evaluate method alone is more than 100 lines due to boundary checking, edge-case handling, etc.) I would like to see if we can find a more minimal way to introduce eval that doesn't expose the user to so much code complexity.

I think Felipe's suggestions of offloading to the dataset class, utilities, etc. are valuable, but I would also like to re-raise the option of simplifying by only allowing eval for a fixed number of steps (at least for a first pass). Not going to block on this; if we can do the cross-epoch accounting in a cleaner, more minimal way, I am all for it.
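For reference, a fixed-number-of-steps variant along those lines might look roughly like this (a sketch only, assuming the validation dataloader can always yield eval_steps batches; it reuses forward_only and the existing recipe attributes):

    # Sketch, not part of the PR: evaluate for exactly eval_steps batches and
    # report the batch-averaged loss, with no cross-epoch accounting.
    async def evaluate(self) -> dict[str, float]:
        for model_part in self.model_parts:
            model_part.eval()
        val_iter = iter(self.val_dataloader)
        total_loss, num_batches = 0.0, 0
        with torch.no_grad():
            for _ in range(self.eval_steps):
                batch = next(val_iter)
                batch = {
                    k: v.to(self.device) if isinstance(v, torch.Tensor) else v
                    for k, v in batch.items()
                }
                labels = batch.pop("labels")
                total_loss += self.forward_only(batch, labels).item()
                num_batches += 1
        for model_part in self.model_parts:
            model_part.train()
        return {"val_loss": total_loss / max(num_batches, 1), "val_batches": num_batches}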

"""Run evaluation with async all_reduce for cross-rank epoch synchronization."""
Review comment (Contributor): It might be worth enhancing this docstring a bit, maybe with a small numerical example.
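One possible expansion along those lines (a sketch only; the numbers are purely illustrative):

    # Sketch of a fuller docstring, not part of the PR:
    """Run evaluation with async all_reduce for cross-rank epoch synchronization.

    Each rank prefetches the next batch and launches a non-blocking MAX
    all_reduce on that batch's epoch increment while the current batch is being
    evaluated. Example: with 2 ranks and eval_steps=0 (full epoch), if rank 0
    wraps into epoch 1 after 37 batches while rank 1 is still in epoch 0, the
    reduced value becomes 1 on both ranks, so both stop after the same step and
    neither blocks waiting for the other.
    """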

logger.info("=" * 50)
logger.info("STARTING EVALUATION")
logger.info("=" * 50)

for model_part in self.model_parts:
model_part.eval()

val_dataloader = iter(self.val_dataloader)
Review comment (Contributor): Did you have a chance to consider the multi-dataset case? What happens there?

total_loss, num_batches, starting_epoch = 0.0, 0, None

# Prefetch first batch
try:
next_batch = next(val_dataloader)
except StopIteration:
Review comment (@felipemello1, Oct 20, 2025): I think we can remove the defensive checks and assume that the dataset is infinite. We have a class for it; you can just assert that it is a TuneIterableDataset (we need to update the name and remove "tune", but don't worry about that in this PR):

class InfiniteTuneIterableDataset(TuneIterableDataset):

Then we know it is infinite, and we can remove the try/except here and later in the loop. WDYT?
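A minimal sketch of that assertion (the import path and the dataloader's .dataset attribute are assumptions here, not confirmed by the PR):

    # Sketch, not part of the PR: assert the infinite-dataset contract once so the
    # StopIteration handling around next(val_dataloader) can be dropped.
    from forge.data.datasets import InfiniteTuneIterableDataset  # import path is an assumption

    assert isinstance(self.val_dataloader.dataset, InfiniteTuneIterableDataset), (
        "evaluate() assumes an infinite (epoch-wrapping) validation dataset"
    )
    next_batch = next(val_dataloader)  # no try/except needed for an infinite dataset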

logger.warning("Validation dataloader is empty")
return {"val_loss": 0.0, "val_batches": 0}

should_break, pending_work, epoch_tensor = False, None, None

with torch.no_grad():
while True:
# Wait for previous async all_reduce to complete
if pending_work is not None:
Review comment (@felipemello1, Oct 20, 2025): I am thinking we could abstract most of this into a utility and end up with something like this (feel free to change the variable names):

epoch_incremented, next_max_epoch = False, None
with torch.no_grad():
    while True:
        # Check whether the epoch incremented before getting a new batch.
        # If so, stop iterating over the dataset.
        epoch_incremented: bool = check_if_epoch_incremented(batch, next_max_epoch)
        if epoch_incremented:
            logger.info("bla bla bla")
            break

        # Get the next batch
        batch = next_batch
        next_batch = next(val_dataloader)

        # Start a non-blocking all_reduce for the next batch's epoch
        next_max_epoch: futures = get_distributed_max_epoch(next_batch)

Review comment (Contributor): Not 100% sure this works; I think get_distributed_max_epoch may need to return a tensor and futures?
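A rough sketch of such a utility, returning both the tensor and the async work handle as that comment suggests (the names follow the pseudocode above and are not part of the PR):

    # Sketch: start a non-blocking MAX all_reduce on the batch's epoch counter and
    # return both pieces so the caller can overlap the reduction with compute.
    def get_distributed_max_epoch(self, batch: dict):
        """Return (epoch_tensor, work_handle), or (None, None) if not applicable."""
        epoch = self._extract_epoch_from_batch(batch)
        if epoch is None or not torch.distributed.is_initialized():
            return None, None
        epoch_tensor = torch.tensor([epoch], dtype=torch.long, device=self.device)
        work = torch.distributed.all_reduce(
            epoch_tensor, op=torch.distributed.ReduceOp.MAX, async_op=True
        )
        return epoch_tensor, work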

pending_work.wait()
should_break = (
epoch_tensor.item() > 0 if epoch_tensor is not None else False
Review comment (Contributor): As a rule of thumb, it's not good to use item(), because it requires CPU-GPU synchronization.

)
pending_work = None

if should_break:
logger.info(
"Epoch completed across all ranks - stopping evaluation"
)
break

if self.eval_steps > 0 and num_batches >= self.eval_steps:
logger.info(f"Reached eval_steps cap of {self.eval_steps}")
break

batch = next_batch

# Track starting epoch
current_epoch = self._extract_epoch_from_batch(batch)
if current_epoch is not None and starting_epoch is None:
starting_epoch = current_epoch

# Prefetch next batch and start async epoch check
try:
next_batch = next(val_dataloader)
next_epoch = self._extract_epoch_from_batch(next_batch)

if next_epoch is not None and starting_epoch is not None:
Review comment (@felipemello1, Oct 20, 2025): I am tempted to say let's remove the if/else checks for None. Do you see a strong argument for keeping them? It would only be a problem if someone replaced our dataset abstraction; otherwise we always have the metric.

epoch_increment = next_epoch - starting_epoch
if torch.distributed.is_initialized():
epoch_tensor = torch.tensor(
[epoch_increment], dtype=torch.long, device=self.device
)
pending_work = torch.distributed.all_reduce(
epoch_tensor,
op=torch.distributed.ReduceOp.MAX,
async_op=True,
)
else:
should_break = epoch_increment > 0
except StopIteration:
should_break = True

# Process current batch (overlaps with async all_reduce)
for k, v in batch.items():
if isinstance(v, torch.Tensor):
batch[k] = v.to(self.device)

labels = batch.pop("labels")
loss = self.forward_only(batch, labels)
total_loss += loss.item()
num_batches += 1

if num_batches % 10 == 0:
logger.info(f" Eval batch {num_batches} | Loss: {loss.item():.4f}")

for model_part in self.model_parts:
model_part.train()

avg_loss = total_loss / max(num_batches, 1)
Review comment (@felipemello1, Oct 20, 2025): This might be incorrect. If we have one sample with 10 tokens and another sample with 100 tokens, we have to divide by 110, not by 2. @ebsmothers can you confirm?

Reply (Contributor): This kind of depends. One interpretation is that the val loss we report is the average over all batches in the val dataloader; in that case this would be the correct implementation. A second interpretation is that the val loss we report is the aggregate loss over all tokens in the val dataset. This is the more "correct" one, as @felipemello1 points out. It especially matters for training; by the way, this is why I did not enable gradient accumulation initially, as it needs some special handling. See e.g. the "Long digression" section in the summary of meta-pytorch/torchtune#1917. Since then titan has made some strides in this direction, see this function and this issue.

So I think the ideal state is that we use token-normalized loss (rather than batch-normalized, as is done here) for both training and validation. But for training it will likely require a little more work to fully support in titan. For validation it's more straightforward; you can see this snippet in torchtune's validation logic.
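A minimal sketch of the token-normalized option for validation (assuming loss_fn returns a mean over non-ignored label tokens and that -100 is the ignore index; both are assumptions, not confirmed by this PR):

    # Sketch, not part of the PR: weight each batch's loss by its token count so the
    # reported val loss is normalized over tokens rather than over batches.
    total_loss = torch.zeros((), device=self.device)
    total_tokens = torch.zeros((), device=self.device)

    # inside the eval loop, after computing `loss` for the current batch:
    num_tokens = (labels != -100).sum()       # -100 as ignore_index is an assumption
    total_loss += loss.detach() * num_tokens  # undo the per-batch mean
    total_tokens += num_tokens

    # after the loop:
    avg_loss = (total_loss / total_tokens.clamp(min=1)).item()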

logger.info(
f"EVALUATION COMPLETE | Val Loss: {avg_loss:.4f} | Batches: {num_batches}"
)
logger.info("=" * 50)

return {"val_loss": avg_loss, "val_batches": num_batches}

@endpoint
async def train(self) -> None:
dataloader = iter(self.train_dataloader)
self.optimizers.zero_grad()

# TODO: tqdm is broken in Monarch actors
# self.pbar = tqdm(initial=self.current_step, total=self.num_training_steps)

@@ -254,18 +418,21 @@ async def train(self) -> None:
# Move tensors to the appropriate device
for k, v in batch.items():
if isinstance(v, torch.Tensor):
batch[k] = v.to("cuda") # TODO: hardcoded for now
batch[k] = v.to(self.device) # TODO: hardcoded for now
self.train_step(batch)
# self.profiler.step()
self.current_step += 1

# Run evaluation periodically if enabled
if self.validation_enabled and self.current_step % self.eval_interval == 0:
eval_metrics = await self.evaluate()
logger.info(f"Step {self.current_step} | Eval metrics: {eval_metrics}")

# Save checkpoints
self.checkpointer.save(
curr_step=self.current_step,
last_step=self.current_step == self.num_training_steps,
)

# self.pbar.close()

@endpoint
async def cleanup(self) -> None:
if self.checkpointer:
15 changes: 14 additions & 1 deletion apps/sft/qwen3_8b.yaml
@@ -25,13 +25,26 @@ optimizer:
lr_scheduler:
warmup_steps: 200

# Dataset configuration
dataset:
path: "yahma/alpaca-cleaned"
split: "train[:95%]"

dataset_val:
path: "yahma/alpaca-cleaned"
split: "train[95%:]"

training:
local_batch_size: 1
seq_len: 2048
max_norm: 1.0
steps: 1000
compile: false
dataset: "c4"

validation:
enabled: true # Enable/disable validation
eval_interval: 100 # Run evaluation every 100 training steps
eval_steps: 50 # Number of batches per evaluation (0 = full epoch)

parallelism:
data_parallel_replicate_degree: 1