diff --git a/recipes/configs/qwen2/0.5B_full_single_device_muon.yaml b/recipes/configs/qwen2/0.5B_full_single_device_muon.yaml
new file mode 100644
index 0000000000..c16388fb99
--- /dev/null
+++ b/recipes/configs/qwen2/0.5B_full_single_device_muon.yaml
@@ -0,0 +1,112 @@
# Config for single device full finetuning in full_finetune_single_device.py
# using a Qwen2 0.5B model with the Muon optimizer
#
# This config assumes that you've run the following command before launching
# this run:
#   tune download Qwen/Qwen2-0.5B-Instruct --output-dir /tmp/Qwen2-0.5B-Instruct
#
# To launch on a single device, run the following command from root:
#   tune run full_finetune_single_device --config qwen2/0.5B_full_single_device_muon
#
# You can add specific overrides through the command line. For example,
# to override the checkpointer directory while launching training
# you can run:
#   tune run full_finetune_single_device --config qwen2/0.5B_full_single_device_muon checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
#
# This config works only for training on a single device.

output_dir: /tmp/torchtune/qwen2_0_5B/full_single_device # /tmp may be deleted by your system. Change it to your preference.

# Tokenizer
tokenizer:
  _component_: torchtune.models.qwen2.qwen2_tokenizer
  path: /tmp/Qwen2-0.5B-Instruct/vocab.json
  merges_file: /tmp/Qwen2-0.5B-Instruct/merges.txt
  max_seq_len: null

# Dataset
dataset:
  _component_: torchtune.datasets.alpaca_cleaned_dataset
  packed: False # True increases speed
seed: null
shuffle: False

# Model Arguments
model:
  _component_: torchtune.models.qwen2.qwen2_0_5b

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Qwen2-0.5B-Instruct
  checkpoint_files: [
    model.safetensors
  ]
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: QWEN2
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 3
epochs: 1
optimizer:
  _component_: torchtune.modules.optim.Muon
  lr: 0.02                  # Muon learning rate (spectral norm of each update)
  momentum: 0.95
  adamw_lr: 2e-5            # learning rate for parameters handled by the internal AdamW
  adamw_betas: [0.9, 0.95]
  adamw_eps: 1e-10
  adamw_wd: 0

loss:
  _component_: torchtune.modules.loss.LinearCrossEntropyLoss
optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1

max_steps_per_epoch: null
gradient_accumulation_steps: 1 # Use to increase effective batch size
clip_grad_norm: null
compile: False # torch.compile the model + loss, True increases speed + decreases memory

# Training environment
device: cuda

# Memory management
enable_activation_checkpointing: True # True reduces memory
enable_activation_offloading: False # True reduces memory

# Reduced precision
dtype: bf16

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
log_level: INFO # DEBUG, WARN, etc.
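
# Note on parameter routing: with the default muon_selector in
# torchtune.modules.optim.Muon, 2D weight matrices (e.g. attention and MLP
# projections) are updated with Muon, while embeddings, the output head, biases,
# and other parameters with fewer than two dimensions fall back to the internal
# AdamW controlled by the adamw_* settings above.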


# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False

  # Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  # `torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  # Trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 3
  active_steps: 2
  num_cycles: 1
diff --git a/recipes/full_finetune_single_device.py b/recipes/full_finetune_single_device.py
index cd232d797c..db2b7c7600 100644
--- a/recipes/full_finetune_single_device.py
+++ b/recipes/full_finetune_single_device.py
@@ -441,8 +441,10 @@ def _setup_optimizer(
                 **cfg_optimizer,
             )
         else:
+            # Muon selects parameters by name, so pass named_parameters() for Muon variants;
+            # other optimizers keep the plain parameter iterator.
+            optimizer_component = cfg_optimizer["_component_"]
+            is_muon = "muon" in optimizer_component.lower()
+            params = self._model.named_parameters() if is_muon else self._model.parameters()
             optimizer = config.instantiate(
-                cfg_optimizer, params=self._model.parameters()
+                cfg_optimizer, params=params
             )
             if opt_state_dict:
                 optimizer.load_state_dict(opt_state_dict)
diff --git a/torchtune/modules/optim.py b/torchtune/modules/optim.py
index 4e7d53d45c..16fc0a1e87 100644
--- a/torchtune/modules/optim.py
+++ b/torchtune/modules/optim.py
@@ -8,6 +8,8 @@
 import torch
 from torch.optim import Optimizer
+import torch.distributed as dist
+from torch.distributed.tensor import distribute_tensor, DTensor

 __all__ = ["OptimizerInBackward"]

@@ -82,3 +84,204 @@ def load_state_dict(self, state_dict):
         )
         for idx, opt in self._optimizers.items():
             opt.load_state_dict(state_dict["optimizers"][str(idx)])


class Muon(torch.optim.Optimizer):
    """
    Muon - MomentUm Orthogonalized by Newton-Schulz.

    Muon internally runs standard SGD-momentum, and then performs an orthogonalization
    post-processing step, in which each 2D parameter's update is replaced with the nearest
    orthogonal matrix. To efficiently orthogonalize each update, we use a Newton-Schulz
    iteration, which has the advantage that it can be stably run in bfloat16 on the GPU.

    Some warnings:
    - We believe this optimizer is unlikely to work well for training with small batch sizes.
    - We believe it may not work well for finetuning pretrained models, but we haven't tested this.

    Arguments:
        params: The parameters to be optimized, as an iterable of ``(name, parameter)`` pairs
            (e.g. ``model.named_parameters()``).
        muon_selector: Optional callable ``(name, param) -> bool`` deciding which parameters are
            updated with Muon. By default, trainable parameters with at least two dimensions,
            excluding embeddings, the output head, and biases.
        lr: The learning rate. The updates will have spectral norm of ``lr``. (0.02 is a good default)
        momentum: The momentum used by the internal SGD. (0.95 is a good default)
        nesterov: Whether to use Nesterov-style momentum in the internal SGD. (recommended)
        ns_steps: The number of Newton-Schulz iterations to run. (6 is probably always enough)
        adamw_lr: The learning rate for the internal AdamW.
        adamw_betas: The betas for the internal AdamW.
        adamw_eps: The epsilon for the internal AdamW.
        adamw_wd: The weight decay for the internal AdamW.
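
    Example (illustrative; assumes ``model`` is an ``nn.Module`` and ``inputs`` is a batch):
        >>> opt = Muon(model.named_parameters(), lr=0.02, momentum=0.95, adamw_lr=2e-5)
        >>> loss = model(inputs).sum()
        >>> loss.backward()
        >>> opt.step()
        >>> opt.zero_grad()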
+ """ + def __init__(self, params, muon_selector=None, lr=0.02, momentum=0.95, nesterov=True, ns_steps=6, + adamw_lr=3e-4, adamw_betas=[0.95, 0.95], adamw_eps=1e-8, adamw_wd=0): + + defaults = dict(lr=lr, momentum=momentum, nesterov=nesterov, ns_steps=ns_steps, + adamw_lr_ratio=adamw_lr/lr, adamw_betas=adamw_betas, + adamw_eps=adamw_eps, adamw_wd=adamw_wd) + + if muon_selector is None: + muon_selector = lambda name, param: ( + param.requires_grad and + param.ndim >= 2 and # Check if scalar + "embed" not in name.lower() and # Check if embedding layer + "tok" not in name.lower() and # Check if token embeddings + "head" not in name.lower() and # Check if output head + "bias" not in name.lower() # Check if bias term + ) + + named_params = list(params) + + muon_params = [p for n, p in named_params if muon_selector(n, p)] + adamw_params = [p for n, p in named_params if not muon_selector(n, p)] + + super().__init__([*muon_params, *adamw_params], defaults) + + # Sort parameters into those for which we will use Muon, and those for which we will not + # we cant pickle booleans for saving, so we will use 1=True, 0=False + def assign_muon(p): + if p.ndim >= 2 and p.size(0) < 10000: + self.state[p]['use_muon'] = 1 + else: + self.state[p]['use_muon'] = 0 + + if isinstance(muon_params[0], dict): + for group in muon_params: + for p in group['params']: + assign_muon(p) + else: + for p in muon_params: + assign_muon(p) + + def assign_adamw(p): + # Do not use Muon for parameters in adamw_params + self.state[p]['use_muon'] = 0 + + if len(adamw_params) and isinstance(adamw_params[0], dict): + for group in adamw_params: + for p in group['params']: + assign_adamw(p) + else: + for p in adamw_params: + assign_adamw(p) + + if torch.distributed.is_initialized(): + self.world_size = torch.distributed.get_world_size() + self.rank = torch.distributed.get_rank() + else: + self.world_size = 1 + self.rank = 0 + + def to_dist(self, x, from_local=False, **meta): + if from_local: + return DTensor.from_local( + x, + device_mesh=meta["device_mesh"], + placements=meta["placements"], + shape=meta["shape"], + stride=meta["stride"], + ) + else: + return distribute_tensor(x, device_mesh=meta["device_mesh"], placements=meta["placements"]) + + + def to_local(self, x, keep_sharded=False): + if isinstance(x, DTensor): + meta = dict( + device_mesh=x.device_mesh, + placements=x.placements, + shape=x.shape, + stride=x.stride(), + ) + if keep_sharded: + return x.to_local(), meta + else: + return x.full_tensor(), meta + + return x, None + + def zeropower_via_newtonschulz5(self, G, steps=10, eps=1e-7): + """ + Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a + quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose + of minimizing steps, it turns out to be empirically effective to keep increasing the slope at + zero even beyond the point where the iteration no longer converges all the way to one everywhere + on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T + where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model + performance at all relative to UV^T, where USV^T = G is the SVD. 
+ """ + assert len(G.shape) == 2 + a, b, c = (3.4445, -4.7750, 2.0315) + X = G.bfloat16() + X /= (X.norm() + eps) # ensure top singular value <= 1 + if G.size(0) > G.size(1): + X = X.T + for _ in range(steps): + A = X @ X.T + B = b * A + c * A @ A + X = a * X + B @ X + if G.size(0) > G.size(1): + X = X.T + return X + + def step(self, closure=None): + """Perform a single optimization step. + + Args: + closure (Callable, optional): A closure that reevaluates the model + and returns the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + lr = group["lr"] + momentum = group['momentum'] + for i, p in enumerate(group['params']): + if self.state[p]['use_muon'] == 1: + g = p.grad + if g is None: + continue + if g.ndim > 2: + g = g.view(g.size(0), -1) + state = self.state[p] + if 'momentum_buffer' not in state: + state['momentum_buffer'] = torch.zeros_like(g) + buf = state['momentum_buffer'] + buf.mul_(momentum).add_(g) + if group['nesterov']: + g = g.add(buf, alpha=momentum) + + meta = None + if isinstance(g, DTensor): + g, meta = self.to_local(g, keep_sharded=False) + # gives NaNs when done with Dtensor, instead of throwing a typical op not supported error, quite sneaky + g = self.zeropower_via_newtonschulz5(g, steps=group['ns_steps']) + if meta is not None: + g = self.to_dist(g, **meta) + g *= max(1, g.size(0)/g.size(1))**0.5 + + g = g.view_as(p.data).type_as(p.data) + p.data.add_(g, alpha=-lr) + else: + # these are all pointwise so we can stay in Dtensor + g = p.grad + if g is None: + continue + state = self.state[p] + if 'step' not in state: + state['step'] = 0 + state['moment1'] = torch.zeros_like(g) + state['moment2'] = torch.zeros_like(g) + state['step'] += 1 + step = state['step'] + buf1 = state['moment1'] + buf2 = state['moment2'] + buf1.lerp_(g, 1-group['adamw_betas'][0]) + buf2.lerp_(g.square(), 1-group['adamw_betas'][1]) + + g = buf1 / (group['adamw_eps'] + buf2.sqrt()) + + bias_correction1 = 1 - group['adamw_betas'][0]**step + bias_correction2 = 1 - group['adamw_betas'][1]**step + scale = bias_correction1 / bias_correction2**0.5 + p.data.mul_(1 - lr * group['adamw_wd']) + p.data.add_(g, alpha=-lr/scale) \ No newline at end of file diff --git a/torchtune/training/lr_schedulers.py b/torchtune/training/lr_schedulers.py index 6f431c9f37..be8d85debf 100644 --- a/torchtune/training/lr_schedulers.py +++ b/torchtune/training/lr_schedulers.py @@ -10,6 +10,7 @@ import torch from torch.optim.lr_scheduler import LambdaLR from torchtune.training.memory import OptimizerInBackwardWrapper +from torchtune.modules.optim import Muon def get_cosine_schedule_with_warmup( @@ -88,7 +89,9 @@ def get_lr( ) # LR Schedulers are the same across all param groups for full_finetune right now + lr = param_groups[0]["lr"] + if isinstance(optimizer, Muon): return lr # return Muon learning rate if Muon optimizer for group in param_groups: if group["lr"] != lr: raise RuntimeError("LR Schedulers are different across all param groups ")