114 changes: 114 additions & 0 deletions recipes/configs/qwen2/0.5B_full_single_device_muon.yaml
@@ -0,0 +1,114 @@
# Config for single device full finetuning in full_finetune_single_device.py
# using a Qwen2 0.5B model with the Muon optimizer
#
# This config assumes that you've run the following command before launching
# this run:
# tune download Qwen/Qwen2-0.5B-Instruct --output-dir /tmp/Qwen2-0.5B-Instruct
#
# To launch on a single device, run the following command from root:
# tune run full_finetune_single_device --config qwen2/0.5B_full_single_device_muon
#
# You can add specific overrides through the command line. For example
# to override the checkpointer directory while launching training
# you can run:
# tune run full_finetune_single_device --config qwen2/0.5B_full_single_device_muon checkpointer.checkpoint_dir=<YOUR_CHECKPOINT_DIR>
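#
# Similarly, to experiment with the Muon learning rate (see muon.lr below), you can pass, for example:
# tune run full_finetune_single_device --config qwen2/0.5B_full_single_device_muon muon.lr=0.02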
#
# This config works only for training on single device.

output_dir: /tmp/torchtune/qwen2_0_5B/full_single_device # /tmp may be deleted by your system. Change it to your preference.

# Tokenizer
tokenizer:
_component_: torchtune.models.qwen2.qwen2_tokenizer
path: /tmp/Qwen2-0.5B-Instruct/vocab.json
merges_file: /tmp/Qwen2-0.5B-Instruct/merges.txt
max_seq_len: null

# Dataset
dataset:
_component_: torchtune.datasets.alpaca_cleaned_dataset
packed: False # True increases speed
seed: null
shuffle: False #True

# Model Arguments
model:
_component_: torchtune.models.qwen2.qwen2_0_5b

checkpointer:
_component_: torchtune.training.FullModelHFCheckpointer
checkpoint_dir: /tmp/Qwen2-0.5B-Instruct
checkpoint_files: [
model.safetensors
]
recipe_checkpoint: null
output_dir: ${output_dir}
model_type: QWEN2
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 3
epochs: 1
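# Muon optimizer settings. When enabled, Muon is applied to the model's 2D weight matrices,
# while embeddings, the LM head, and scalar parameters are handled by the optimizer
# configured below (see _setup_optimizer in recipes/full_finetune_single_device.py).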
muon:
enabled: True
_component_: torchtune.modules.SingleDeviceMuon
momentum: 0.95
lr: 5e-4 #0.02
weight_decay: 0
optimizer:
_component_: torch.optim.AdamW
fused: True
lr: 2e-5

loss:
_component_: torchtune.modules.loss.LinearCrossEntropyLoss
optimizer_in_bwd: False # True saves memory. Requires gradient_accumulation_steps=1

max_steps_per_epoch: null
gradient_accumulation_steps: 1 # Use to increase effective batch size
clip_grad_norm: null
compile: False # torch.compile the model + loss, True increases speed + decreases memory

# Training environment
device: cuda

# Memory management
enable_activation_checkpointing: True # True reduces memory
enable_activation_offloading: False # True reduces memory

# Reduced precision
dtype: bf16

# Logging
metric_logger:
_component_: torchtune.training.metric_logging.DiskLogger
log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True
log_level: INFO # DEBUG, WARN, etc.


# Profiler (disabled)
profiler:
_component_: torchtune.training.setup_torch_profiler
enabled: False

#Output directory of trace artifacts
output_dir: ${output_dir}/profiling_outputs

#`torch.profiler.ProfilerActivity` types to trace
cpu: True
cuda: True

#trace options passed to `torch.profiler.profile`
profile_memory: False
with_stack: False
record_shapes: True
with_flops: False

# `torch.profiler.schedule` options:
# wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
wait_steps: 5
warmup_steps: 3
active_steps: 2
num_cycles: 1
73 changes: 70 additions & 3 deletions recipes/full_finetune_single_device.py
@@ -272,8 +272,8 @@ def setup(self, cfg: DictConfig) -> None:

# _setup_optimizer should take in ckpt_dict only if training is resumed from
# checkpoint. Transforming the opt state dict is handled by this method
self.optimizer = self._setup_optimizer(
cfg_optimizer=cfg.optimizer,
self.optimizer, self.muon = self._setup_optimizer(
cfg=cfg,
opt_state_dict=(
ckpt_dict[training.OPT_KEY] if training.OPT_KEY in ckpt_dict else None
),
@@ -424,7 +424,8 @@ def _setup_model(

return model

def _setup_optimizer(
# TODO: Remove this function
def _setup_optimizer_delete(
self,
cfg_optimizer: DictConfig,
opt_state_dict: Optional[dict[str, Any]] = None,
@@ -444,6 +445,66 @@ def _setup_optimizer(
optimizer.load_state_dict(opt_state_dict)
self._logger.info("Optimizer is initialized.")
return optimizer

def _setup_optimizer(self, cfg: DictConfig, opt_state_dict: Optional[dict[str, Any]] = None) -> tuple[Optimizer, Optional[Optimizer]]:
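"""Set up the optimizer(s) from the config.

When cfg.muon.enabled is True, trainable 2D weight matrices are optimized with Muon
(instantiated from cfg.muon), while embeddings, the LM head, and scalar parameters fall
back to the optimizer instantiated from cfg.optimizer. Returns (optimizer, muon), where
muon is None when Muon is disabled.
"""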
cfg_optimizer = cfg.optimizer
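# Pop the 'enabled' flag so it is not forwarded to the Muon constructor during instantiation.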
muon_enabled = cfg.muon.pop("enabled")
cfg_muon = cfg.muon

if muon_enabled:
if self.optimizer_in_bwd:
# TODO: Modify optimizer_in_bwd for muon
pass
else:
muon_params = []
non_muon_params = []

for name, module in self._model.named_modules():
for param_name, param in module.named_parameters(recurse=False):
full_name = f"{name}.{param_name}" if name else param_name

if not param.requires_grad:
continue

# Skip if embedding
if isinstance(module, nn.Embedding) or "embed" in full_name.lower():
non_muon_params.append(param)
# Skip if scalar (ndim < 2)
elif param.ndim < 2:
non_muon_params.append(param)
# Skip known head layers
elif "lm_head" in full_name.lower():
non_muon_params.append(param)
else:
muon_params.append(param)

optimizer = config.instantiate(
cfg_optimizer, params=non_muon_params
)
muon = config.instantiate(
cfg_muon, params=muon_params
)
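# Note: only the base optimizer state is restored here; Muon's state is not loaded from the checkpoint.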
if opt_state_dict:
optimizer.load_state_dict(opt_state_dict)
self._logger.info("Optimizer is initialized.")
return optimizer, muon

else:
if self.optimizer_in_bwd:
optimizer_cls = _get_component_from_path(cfg_optimizer.pop("_component_"))
optimizer = OptimizerInBackward(
params=self._model.parameters(),
optimizer_cls=optimizer_cls,
**cfg_optimizer,
)
else:
optimizer = config.instantiate(
cfg_optimizer, params=self._model.parameters()
)
if opt_state_dict:
optimizer.load_state_dict(opt_state_dict)
self._logger.info("Optimizer is initialized.")
return optimizer, None

def _setup_lr_scheduler(
self,
@@ -555,6 +616,8 @@ def _loss_step(self, batch: dict[str, torch.Tensor]) -> torch.Tensor:

def train(self) -> None:
self.optimizer.zero_grad()
if self.muon:
self.muon.zero_grad()
t0 = time.perf_counter()
running_loss, num_tokens = 0.0, 0
self._profiler.start()
@@ -599,11 +662,15 @@ def train(self) -> None:
# This will be a no-op for optim in bwd, but prevents a warning w/ LR Scheduler
self.optimizer.step()
self.optimizer.zero_grad(set_to_none=True)
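# Muon (when enabled) owns the 2D weight matrices; step it alongside the base optimizer.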
if self.muon:
self.muon.step()
self.muon.zero_grad(set_to_none=True)

if self.lr_scheduler is not None:
self.lr_scheduler.step()

self.global_step += 1
print(f"running_loss: {running_loss} ; num_tokens: {num_tokens}")
loss_value = (
running_loss
/ (num_tokens if not self.optimizer_in_bwd else 1.0)
4 changes: 4 additions & 0 deletions torchtune/modules/__init__.py
@@ -36,6 +36,8 @@
from .vq_embeddings import VectorQuantizedEmbeddings
from .embedding_utils import resize_token_embeddings # usort: skip

from .muon import Muon, SingleDeviceMuon

__all__ = [
"MultiHeadAttention",
"TanhGate",
@@ -63,4 +65,6 @@
"classifier_model",
"rms_norm",
"resize_token_embeddings",
"Muon",
"SingleDeviceMuon"
]
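For context, the block below is a minimal, illustrative sketch of the update rule commonly associated with Muon: momentum SGD whose update for a 2D weight matrix is orthogonalized with a Newton-Schulz iteration before being applied. It is an assumption for readers unfamiliar with the optimizer, not the torchtune.modules.muon implementation; the function names and coefficients are illustrative, and the lr/momentum defaults simply mirror this config (5e-4, 0.95).

import torch

def _orthogonalize_newton_schulz(g: torch.Tensor, steps: int = 5) -> torch.Tensor:
    # Approximately replace the singular values of g with 1 via a quintic Newton-Schulz iteration.
    # The coefficients below are the commonly cited ones; treat them as an assumption here.
    a, b, c = 3.4445, -4.7750, 2.0315
    x = g.to(torch.bfloat16)
    x = x / (x.norm() + 1e-7)  # normalize so the iteration is stable
    transposed = x.size(0) > x.size(1)
    if transposed:  # iterate on the wide orientation
        x = x.T
    for _ in range(steps):
        s = x @ x.T
        x = a * x + (b * s + c * s @ s) @ x
    if transposed:
        x = x.T
    return x.to(g.dtype)

def muon_step(p: torch.Tensor, buf: torch.Tensor, lr: float = 5e-4, momentum: float = 0.95) -> None:
    # One illustrative Muon step for a single 2D parameter p with p.grad populated.
    buf.mul_(momentum).add_(p.grad)               # momentum accumulation
    update = _orthogonalize_newton_schulz(buf)    # orthogonalized update direction
    p.data.add_(update, alpha=-lr)                # SGD-style parameter update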