diff --git a/recipes/full_finetune_distributed.py b/recipes/full_finetune_distributed.py
index 9e30a61329..167f77ed31 100644
--- a/recipes/full_finetune_distributed.py
+++ b/recipes/full_finetune_distributed.py
@@ -30,6 +30,7 @@
 from torchtune.modules.embedding_utils import resize_token_embeddings
 from torchtune.modules.loss import SFTLoss
 from torchtune.modules.moe import utils as moe_utils
+from torchtune.modules.optim import OptimizerInBackward
 from torchtune.recipe_interfaces import FTRecipeInterface
 from torchtune.training import (
     DummyProfiler,
@@ -378,7 +379,6 @@ def setup(self, cfg: DictConfig) -> None:
 
         self._optimizer = self._setup_optimizer(
             cfg_optimizer=cfg.optimizer,
-            optimizer_in_bwd=self._optimizer_in_bwd,
             opt_state_dict=(
                 checkpoint_dict[training.OPT_KEY]
                 if training.OPT_KEY in checkpoint_dict
@@ -404,12 +404,7 @@ def setup(self, cfg: DictConfig) -> None:
             try:
                 checkpoint_dict = (
                     self._checkpoint_client.load_distributed_checkpoint(
-                        self._model,
-                        (
-                            self._optim_ckpt_wrapper
-                            if self._optimizer_in_bwd
-                            else self._optimizer
-                        ),
+                        self._model, self._optimizer
                     )
                 )
             except Exception as e:
@@ -511,25 +506,13 @@ def _setup_lr_scheduler(
                 )
             return None
 
-        if self._optimizer_in_bwd:
-            # Use the first optimizer from the wrapper to represent the learning rate
-            optimizer = next(iter(self._optim_ckpt_wrapper.optim_map.values()))
-        else:
-            # Standard case: use the single optimizer
-            optimizer = self._optimizer
-
-        # Instantiate the learning rate scheduler
         lr_scheduler = config.instantiate(
             cfg_lr_scheduler,
-            optimizer,
+            self._optimizer,
             num_training_steps=num_training_steps,
             last_epoch=last_epoch,
         )
 
-        if self._optimizer_in_bwd:
-            # Modify the scheduler for optimizer_in_bwd case
-            self._optim_ckpt_wrapper.set_lr_scheduler(lr_scheduler)
-
         if self._is_rank_zero:
             self._logger.info("Learning rate scheduler is initialized.")
 
@@ -726,55 +709,26 @@ def _setup_model(
     def _setup_optimizer(
         self,
         cfg_optimizer: DictConfig,
-        optimizer_in_bwd: bool = False,
         opt_state_dict: Optional[dict[str, Any]] = None,
-    ) -> Optional[Optimizer]:
-        if optimizer_in_bwd:
-            # Maintain a dict of optims for every parameter.
-            optim_dict = {
-                param: config.instantiate(cfg_optimizer, [param])
-                for param in self._model.parameters()
-            }
-
-            # Register optimizer step hooks on the model to run optimizer in backward.
-            training.register_optim_in_bwd_hooks(
-                model=self._model, optim_dict=optim_dict
-            )
-            # Create a wrapper for checkpoint save/load of optimizer states when running in backward.
-            self._optim_ckpt_wrapper = training.create_optim_in_bwd_wrapper(
-                model=self._model, optim_dict=optim_dict
+    ) -> Optimizer:
+        if self._optimizer_in_bwd:
+            optimizer_cls = _get_component_from_path(cfg_optimizer.pop("_component_"))
+            optimizer = OptimizerInBackward(
+                params=self._model.parameters(),
+                optimizer_cls=optimizer_cls,
+                **cfg_optimizer,
             )
-            # Load optimizer states for each param. If optimizer states are being restored in an optimizer in
-            # backward run, these need to have been saved with the same setting. Cannot restore from runs that
-            # did not use optimizer in backward.
-            if opt_state_dict is not None:
-                for param in opt_state_dict.keys():
-                    try:
-                        training.load_from_full_optimizer_state_dict(
-                            self._model,
-                            self._optim_ckpt_wrapper.optim_map[param],
-                            opt_state_dict[param],
-                            self._device,
-                        )
-                    except BaseException as e:
-                        raise RuntimeError(
-                            "Failed loading in-backward optimizer checkpoints."
-                            "Please make sure run being restored from was using in-backward optimizer."
-                        ) from e
-            utils.log_rank_zero(self._logger, "In-backward optimizers are set up.")
-            return None
         else:
             optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
-            if opt_state_dict:
-                training.load_from_full_optimizer_state_dict(
-                    self._model,
-                    optimizer,
-                    opt_state_dict,
-                    self._device,
-                )
-
-            utils.log_rank_zero(self._logger, "Optimizer is initialized.")
-            return optimizer
+        if opt_state_dict:
+            training.load_from_full_optimizer_state_dict(
+                self._model,
+                optimizer,
+                opt_state_dict,
+                self._device,
+            )
+        utils.log_rank_zero(self._logger, "Optimizer is initialized.")
+        return optimizer
 
     def _setup_data(
         self,
@@ -859,7 +813,7 @@ def validate(self) -> dict[str, float]:
         total_val_tokens = torch.tensor(0.0, device=self._device)
 
         with torch.no_grad():
-            for batch_idx, batch in enumerate(self._val_dataloader):
+            for _, batch in enumerate(self._val_dataloader):
                 utils.batch_to_device(batch, self._device)
 
                 # Count tokens excluding padding
@@ -895,25 +849,12 @@ def validate(self) -> dict[str, float]:
         return log_dict
 
     def train(self) -> None:
-        """
-        The core training loop.
-        """
-        # clean up before training begins
         training.cleanup_before_training()
-
-        # zero out the gradients before starting training
-        if not self._optimizer_in_bwd:
-            self._optimizer.zero_grad()
-        else:
-            for opt in self._optim_ckpt_wrapper.optim_map.values():
-                opt.zero_grad()
-
-        # Initialize tokens count and running loss (for grad accumulation)
+        self._optimizer.zero_grad()
         t0 = time.perf_counter()
-        running_loss = 0
-        num_tokens = 0
-
+        running_loss, num_tokens = 0.0, 0
        self._profiler.start()
 
+        # self.epochs_run should be non-zero when we're resuming from a checkpoint
         for curr_epoch in range(self.epochs_run, self.total_epochs):
             pbar = tqdm(total=self._steps_per_epoch, disable=not self._is_rank_zero)
@@ -944,19 +885,17 @@ def train(self) -> None:
                 current_loss = self._loss_step(batch) * current_num_tokens
                 running_loss += current_loss
 
                 # For optimizer in backward, we need to normalize before calling backward
-                # This case and gradient accumulation are mutually exclusive
                 if self._optimizer_in_bwd:
                     torch.distributed.all_reduce(num_tokens)
                     torch.distributed.all_reduce(running_loss)
-                    current_loss = current_loss * (self.dp_degree / num_tokens)
+                    current_loss = current_loss / num_tokens
                 current_loss.backward()
 
                 # Optimizer step (if not fused in backward call)
                 if (idx + 1) % self._gradient_accumulation_steps == 0:
+                    grad_norm = None
                     if not self._optimizer_in_bwd:
-                        # Get total number of tokens across all ranks to normalize gradients
                         torch.distributed.all_reduce(num_tokens)
-                        # This will ensure that the logged loss matches what we're optimizing
                         torch.distributed.all_reduce(running_loss)
                         # Manually scale the gradients from unnormalized loss by total # of tokens
@@ -974,16 +913,15 @@ def train(self) -> None:
                             # If sharded, collect the DTensor here
                             if isinstance(grad_norm, DTensor):
                                 grad_norm = grad_norm.full_tensor()
-                        self._optimizer.step()
-                        self._optimizer.zero_grad(set_to_none=True)
 
-                    # Update the number of steps when the weights are updated
-                    self.global_step += 1
+                    self._optimizer.step()
+                    self._optimizer.zero_grad(set_to_none=True)
 
-                    # Step the learning rate scheduler
                     if self._lr_scheduler is not None:
                         self._lr_scheduler.step()
 
+                    self.global_step += 1
+
                     # If float8 training is enabled, perform a single all-reduce to compute the
                     # scale for all float8 parameters efficiently instead of doing many small
                     # all-reduces for each parameter
@@ -1008,13 +946,7 @@ def train(self) -> None:
                         time_per_step = time.perf_counter() - t0
                         log_dict = {
                             "loss": loss_to_log,
-                            "lr": get_lr(
-                                (
-                                    self._optimizer
-                                    if not self._optimizer_in_bwd
-                                    else self._optim_ckpt_wrapper
-                                ),
-                            ),
+                            "lr": get_lr(self._optimizer),
                             "tokens_per_second_per_gpu": (
                                 num_tokens / self.parallel_dims.non_data_parallel_size
                             )
@@ -1070,11 +1002,7 @@ def train(self) -> None:
             self.epochs_run += 1
             self._checkpoint_client.save_checkpoint(
                 model=self._model,
-                optimizer=(
-                    self._optimizer
-                    if not self._optimizer_in_bwd
-                    else self._optim_ckpt_wrapper
-                ),
+                optimizer=self._optimizer,
                 training_progress=TrainingProgress(
                     seed=self.seed,
                     epochs_run=self.epochs_run,
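For context: both the removed per-parameter `optim_dict` + `training.register_optim_in_bwd_hooks` code and the new `OptimizerInBackward` wrapper fuse the optimizer step into the backward pass, so each parameter's gradient is applied and freed as soon as it is accumulated. The sketch below is only a minimal illustration of that pattern in plain PyTorch, not the torchtune implementation; it assumes PyTorch >= 2.1 for `Tensor.register_post_accumulate_grad_hook`, and the helper name `setup_optimizer_in_backward` is hypothetical.

```python
# Minimal sketch of the optimizer-in-backward pattern (illustration only,
# not the torchtune implementation). Assumes PyTorch >= 2.1 for
# Tensor.register_post_accumulate_grad_hook.
import torch
import torch.nn as nn


def setup_optimizer_in_backward(model: nn.Module, optimizer_cls, **optim_kwargs):
    # Hypothetical helper: one small optimizer per parameter, stepped from a
    # gradient hook so all gradients never need to be held in memory at once.
    per_param_opts = {}
    for p in model.parameters():
        if not p.requires_grad:
            continue
        opt = optimizer_cls([p], **optim_kwargs)
        per_param_opts[p] = opt

        def _step_hook(param, _opt=opt):
            # Runs right after this parameter's gradient is accumulated.
            _opt.step()
            _opt.zero_grad(set_to_none=True)  # free the gradient immediately

        p.register_post_accumulate_grad_hook(_step_hook)
    return per_param_opts


model = nn.Linear(16, 4)
setup_optimizer_in_backward(model, torch.optim.AdamW, lr=1e-3)
loss = model(torch.randn(8, 16)).sum()
loss.backward()  # parameter updates happen during backward; no explicit step() call
```

Because `OptimizerInBackward` is used here through the standard `Optimizer` surface (`step()`, `zero_grad()`, state dict loading, `get_lr`, the LR scheduler, and checkpoint save), the recipe can treat `self._optimizer` uniformly and the `optimizer_in_bwd` special-casing in `train()`, LR scheduling, and checkpointing can be deleted, which is what most of this diff does.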