134 changes: 31 additions & 103 deletions recipes/full_finetune_distributed.py
@@ -30,6 +30,7 @@
from torchtune.modules.embedding_utils import resize_token_embeddings
from torchtune.modules.loss import SFTLoss
from torchtune.modules.moe import utils as moe_utils
from torchtune.modules.optim import OptimizerInBackward
from torchtune.recipe_interfaces import FTRecipeInterface
from torchtune.training import (
DummyProfiler,
@@ -378,7 +379,6 @@ def setup(self, cfg: DictConfig) -> None:

self._optimizer = self._setup_optimizer(
cfg_optimizer=cfg.optimizer,
optimizer_in_bwd=self._optimizer_in_bwd,
opt_state_dict=(
checkpoint_dict[training.OPT_KEY]
if training.OPT_KEY in checkpoint_dict
@@ -404,12 +404,7 @@ def setup(self, cfg: DictConfig) -> None:
try:
checkpoint_dict = (
self._checkpoint_client.load_distributed_checkpoint(
self._model,
(
self._optim_ckpt_wrapper
if self._optimizer_in_bwd
else self._optimizer
),
self._model, self._optimizer
)
)
except Exception as e:
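With the wrapper gone, the distributed checkpoint path no longer branches on self._optimizer_in_bwd and simply passes self._optimizer through. This relies on OptimizerInBackward following the standard torch.optim.Optimizer state-dict protocol, which is an assumption about its interface rather than something shown in this hunk. A minimal round-trip sketch of that protocol:

import torch

# Hypothetical round-trip illustrating the protocol the checkpoint client needs:
# any object exposing state_dict()/load_state_dict() like torch.optim.Optimizer
# can be saved and restored the same way, wrapper or not.
params = [torch.nn.Parameter(torch.zeros(4))]
optimizer = torch.optim.AdamW(params, lr=1e-3)
saved = optimizer.state_dict()        # serialize per-parameter optimizer state
optimizer.load_state_dict(saved)      # restore it after re-creating the optimizer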
@@ -511,25 +506,13 @@ def _setup_lr_scheduler(
)
return None

if self._optimizer_in_bwd:
# Use the first optimizer from the wrapper to represent the learning rate
optimizer = next(iter(self._optim_ckpt_wrapper.optim_map.values()))
else:
# Standard case: use the single optimizer
optimizer = self._optimizer

# Instantiate the learning rate scheduler
lr_scheduler = config.instantiate(
cfg_lr_scheduler,
optimizer,
self._optimizer,
num_training_steps=num_training_steps,
last_epoch=last_epoch,
)

if self._optimizer_in_bwd:
# Modify the scheduler for optimizer_in_bwd case
self._optim_ckpt_wrapper.set_lr_scheduler(lr_scheduler)

if self._is_rank_zero:
self._logger.info("Learning rate scheduler is initialized.")

@@ -726,55 +709,26 @@ def _setup_model(
def _setup_optimizer(
self,
cfg_optimizer: DictConfig,
optimizer_in_bwd: bool = False,
opt_state_dict: Optional[dict[str, Any]] = None,
) -> Optional[Optimizer]:
if optimizer_in_bwd:
# Maintain a dict of optims for every parameter.
optim_dict = {
param: config.instantiate(cfg_optimizer, [param])
for param in self._model.parameters()
}

# Register optimizer step hooks on the model to run optimizer in backward.
training.register_optim_in_bwd_hooks(
model=self._model, optim_dict=optim_dict
)
# Create a wrapper for checkpoint save/load of optimizer states when running in backward.
self._optim_ckpt_wrapper = training.create_optim_in_bwd_wrapper(
model=self._model, optim_dict=optim_dict
) -> Optimizer:
if self._optimizer_in_bwd:
optimizer_cls = _get_component_from_path(cfg_optimizer.pop("_component_"))
optimizer = OptimizerInBackward(
params=self._model.parameters(),
optimizer_cls=optimizer_cls,
**cfg_optimizer,
)
# Load optimizer states for each param. If optimizer states are being restored in an optimizer in
# backward run, these need to have been saved with the same setting. Cannot restore from runs that
# did not use optimizer in backward.
if opt_state_dict is not None:
for param in opt_state_dict.keys():
try:
training.load_from_full_optimizer_state_dict(
self._model,
self._optim_ckpt_wrapper.optim_map[param],
opt_state_dict[param],
self._device,
)
except BaseException as e:
raise RuntimeError(
"Failed loading in-backward optimizer checkpoints."
"Please make sure run being restored from was using in-backward optimizer."
) from e
utils.log_rank_zero(self._logger, "In-backward optimizers are set up.")
return None
else:
optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
if opt_state_dict:
training.load_from_full_optimizer_state_dict(
self._model,
optimizer,
opt_state_dict,
self._device,
)

utils.log_rank_zero(self._logger, "Optimizer is initialized.")
return optimizer
if opt_state_dict:
training.load_from_full_optimizer_state_dict(
self._model,
optimizer,
opt_state_dict,
self._device,
)
utils.log_rank_zero(self._logger, "Optimizer is initialized.")
return optimizer
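The diff does not show how OptimizerInBackward itself works; the general technique it names is fusing the optimizer step into the backward pass. The sketch below is an illustrative stand-in built on PyTorch's register_post_accumulate_grad_hook, not torchtune's actual implementation, and it also shows why this mode cannot be combined with gradient accumulation: each parameter is updated (and its gradient freed) as soon as its gradient is ready.

import torch

class TinyOptimizerInBackward:
    """Illustrative sketch only, not torchtune's OptimizerInBackward: keep one
    optimizer per parameter and step it from a post-accumulate-grad hook."""

    def __init__(self, params, optimizer_cls, **optimizer_kwargs):
        self._optim_map = {}
        for p in params:
            self._optim_map[p] = optimizer_cls([p], **optimizer_kwargs)
            p.register_post_accumulate_grad_hook(self._step_and_clear)

    def _step_and_clear(self, param: torch.Tensor) -> None:
        opt = self._optim_map[param]
        opt.step()                       # update runs inside loss.backward()
        opt.zero_grad(set_to_none=True)  # free the gradient immediately

    # A real wrapper would also forward state_dict()/load_state_dict()/param_groups
    # so it can be passed anywhere a torch.optim.Optimizer is expected, and could
    # make its own step()/zero_grad() cheap no-ops.

model = torch.nn.Linear(8, 8)
opt = TinyOptimizerInBackward(model.parameters(), torch.optim.AdamW, lr=1e-3)
model(torch.randn(2, 8)).sum().backward()  # parameters are updated during backward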

def _setup_data(
self,
@@ -859,7 +813,7 @@ def validate(self) -> dict[str, float]:
total_val_tokens = torch.tensor(0.0, device=self._device)

with torch.no_grad():
for batch_idx, batch in enumerate(self._val_dataloader):
for _, batch in enumerate(self._val_dataloader):
utils.batch_to_device(batch, self._device)

# Count tokens excluding padding
@@ -895,25 +849,12 @@ def validate(self) -> dict[str, float]:
return log_dict

def train(self) -> None:
"""
The core training loop.
"""
# clean up before training begins
training.cleanup_before_training()

# zero out the gradients before starting training
if not self._optimizer_in_bwd:
self._optimizer.zero_grad()
else:
for opt in self._optim_ckpt_wrapper.optim_map.values():
opt.zero_grad()

# Initialize tokens count and running loss (for grad accumulation)
self._optimizer.zero_grad()
t0 = time.perf_counter()
running_loss = 0
num_tokens = 0

running_loss, num_tokens = 0.0, 0
self._profiler.start()

# self.epochs_run should be non-zero when we're resuming from a checkpoint
for curr_epoch in range(self.epochs_run, self.total_epochs):
pbar = tqdm(total=self._steps_per_epoch, disable=not self._is_rank_zero)
@@ -944,19 +885,17 @@ def train(self) -> None:
current_loss = self._loss_step(batch) * current_num_tokens
running_loss += current_loss
# For optimizer in backward, we need to normalize before calling backward
# This case and gradient accumulation are mutually exclusive
if self._optimizer_in_bwd:
torch.distributed.all_reduce(num_tokens)
torch.distributed.all_reduce(running_loss)
current_loss = current_loss * (self.dp_degree / num_tokens)
current_loss = current_loss / num_tokens
current_loss.backward()

# Optimizer step (if not fused in backward call)
if (idx + 1) % self._gradient_accumulation_steps == 0:
grad_norm = None
if not self._optimizer_in_bwd:
# Get total number of tokens across all ranks to normalize gradients
torch.distributed.all_reduce(num_tokens)
# This will ensure that the logged loss matches what we're optimizing
torch.distributed.all_reduce(running_loss)

# Manually scale the gradients from unnormalized loss by total # of tokens
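The normalization differs between the two paths: when the optimizer runs in backward, parameter updates fire during backward(), so the loss has to be divided by the all-reduced token count up front, whereas the standard path backpropagates the unnormalized loss and scales the accumulated gradients afterwards. A toy check of the equivalence the recipe relies on (illustrative only):

import torch

# Toy check: dividing the loss by the token count before backward() yields the
# same gradients as backpropagating the raw loss and scaling grads afterwards.
w = torch.ones(3, requires_grad=True)
num_tokens = 10.0

(w.mul(2).sum() / num_tokens).backward()   # in-backward path: pre-normalized loss
pre_normalized = w.grad.clone()

w.grad = None
w.mul(2).sum().backward()                  # standard path: raw loss...
w.grad.mul_(1.0 / num_tokens)              # ...then scale gradients by 1 / num_tokens
assert torch.allclose(pre_normalized, w.grad)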
@@ -974,16 +913,15 @@
# If sharded, collect the DTensor here
if isinstance(grad_norm, DTensor):
grad_norm = grad_norm.full_tensor()
self._optimizer.step()
self._optimizer.zero_grad(set_to_none=True)

# Update the number of steps when the weights are updated
self.global_step += 1
self._optimizer.step()
self._optimizer.zero_grad(set_to_none=True)

# Step the learning rate scheduler
if self._lr_scheduler is not None:
self._lr_scheduler.step()

self.global_step += 1

# If float8 training is enabled, perform a single all-reduce to compute the
# scale for all float8 parameters efficiently instead of doing many small
# all-reduces for each parameter
@@ -1008,13 +946,7 @@
time_per_step = time.perf_counter() - t0
log_dict = {
"loss": loss_to_log,
"lr": get_lr(
(
self._optimizer
if not self._optimizer_in_bwd
else self._optim_ckpt_wrapper
),
),
"lr": get_lr(self._optimizer),
"tokens_per_second_per_gpu": (
num_tokens / self.parallel_dims.non_data_parallel_size
)
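Logging also simplifies: get_lr is called on self._optimizer directly instead of choosing between the optimizer and the in-backward checkpoint wrapper. A helper of this kind typically just reads the learning rate from the optimizer's param groups, assuming all groups share one LR; roughly:

# Rough sketch of what an LR-reporting helper needs from the optimizer object
# (illustrative; not the actual torchtune get_lr implementation).
def current_lr(optimizer) -> float:
    return float(optimizer.param_groups[0]["lr"])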
@@ -1070,11 +1002,7 @@
self.epochs_run += 1
self._checkpoint_client.save_checkpoint(
model=self._model,
optimizer=(
self._optimizer
if not self._optimizer_in_bwd
else self._optim_ckpt_wrapper
),
optimizer=self._optimizer,
training_progress=TrainingProgress(
seed=self.seed,
epochs_run=self.epochs_run,