diff --git a/training/arguments.py b/training/arguments.py
index 7c05821..a41ba61 100644
--- a/training/arguments.py
+++ b/training/arguments.py
@@ -1,7 +1,7 @@
 from dataclasses import dataclass, field
-from typing import Optional, List
+from typing import Optional, List, Union
 
-from transformers import Seq2SeqTrainingArguments
+from transformers import Seq2SeqTrainingArguments, TrainingArguments
 
 
 @dataclass
@@ -94,6 +94,30 @@ class ModelArguments:
             "help": "Prompt tokenizer padding side. Defaults to `left`. If the prompt is pre-pended to the codebooks hidden states, it should be padded on the left."
         },
     )
+    use_lora: bool = field(
+        default=False,
+        metadata={"help": "Whether to use LoRA for parameter-efficient fine-tuning."}
+    )
+    lora_rank: int = field(
+        default=8,
+        metadata={"help": "Rank of the LoRA adaptation matrices."}
+    )
+    lora_alpha: float = field(
+        default=16.0,
+        metadata={"help": "LoRA alpha parameter (scaling factor)."}
+    )
+    lora_dropout: float = field(
+        default=0.05,
+        metadata={"help": "Dropout probability for LoRA layers."}
+    )
+    lora_target_modules: List[str] = field(
+        default_factory=lambda: ['q_proj', 'k_proj', 'v_proj', 'out_proj', 'fc1', 'fc2'],
+        metadata={"help": "Names of the modules to apply LoRA to."}
+    )
+    lora_weights_path: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to pretrained LoRA weights to load. If set, the final LoRA weights are also saved to this path."}
+    )
 
 
 @dataclass
@@ -372,4 +396,4 @@ class ParlerTTSTrainingArguments(Seq2SeqTrainingArguments):
     codebook_weights: Optional[List[float]] = field(
         default=None,
         metadata={"help": "Weights applied to each codebook."},
-    )
\ No newline at end of file
+    )
diff --git a/training/lora.py b/training/lora.py
new file mode 100644
index 0000000..88c9d5c
--- /dev/null
+++ b/training/lora.py
@@ -0,0 +1,213 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple, Union
+
+class LoRALinear(nn.Module):
+    """
+    LoRA-adapted Linear layer implemented with plain PyTorch primitives.
+    """
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        weight: Optional[torch.Tensor] = None,
+        bias: Optional[torch.Tensor] = None,
+        r: int = 8,
+        lora_alpha: float = 16,
+        lora_dropout: float = 0.0
+    ):
+        super().__init__()
+
+        # Store original layer parameters
+        if weight is None:
+            # Initialize with zeros to make it obvious if weights aren't set properly
+            self.weight = nn.Parameter(torch.zeros(out_features, in_features))
+        else:
+            self.weight = nn.Parameter(weight.clone())
+
+        if bias is not None:
+            self.bias = nn.Parameter(bias.clone())
+        else:
+            self.bias = None
+
+        # LoRA specific parameters
+        self.r = r
+        self.lora_alpha = lora_alpha
+        self.scaling = lora_alpha / r
+
+        # LoRA low-rank matrices
+        # A is initialized with Kaiming-uniform and B with zeros, per the original LoRA paper
+        self.lora_A = nn.Parameter(torch.empty(r, in_features))
+        self.lora_B = nn.Parameter(torch.empty(out_features, r))
+        self.reset_lora_parameters()
+
+        # Optional dropout
+        self.lora_dropout = nn.Dropout(p=lora_dropout) if lora_dropout > 0 else nn.Identity()
+
+        # For tracking active status
+        self.active = True
+
+    def reset_lora_parameters(self):
+        """Reset LoRA parameters using Kaiming-uniform initialization."""
+        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
+        nn.init.zeros_(self.lora_B)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # Main linear operation
+        result = F.linear(x, self.weight, self.bias)
+
+        # Add the LoRA contribution when active
+        if self.active:
+            # Apply dropout to the input
+            lora_x = self.lora_dropout(x)
+
+            # Low-rank update: (x @ A^T) @ B^T, scaled by lora_alpha / r
+            lora_result = (lora_x @ self.lora_A.T) @ self.lora_B.T
+            result += lora_result * self.scaling
+
+        return result
+
+    def set_active(self, active: bool):
+        """Set whether LoRA adaptation is active."""
+        self.active = active
+
+
+class LoRAModuleMixin:
+    """
+    Mixin that adds LoRA utilities (freezing, saving/loading, toggling) to a model.
+    """
+    def mark_only_lora_as_trainable(self):
+        """Freeze all parameters except LoRA parameters."""
+        for param in self.parameters():
+            param.requires_grad = False
+
+        for name, param in self.named_parameters():
+            if "lora_A" in name or "lora_B" in name:
+                param.requires_grad = True
+
+    def get_lora_state_dict(self) -> Dict[str, torch.Tensor]:
+        """Get a state dict containing only LoRA parameters."""
+        lora_state_dict = {}
+        for name, param in self.named_parameters():
+            if "lora_A" in name or "lora_B" in name:
+                lora_state_dict[name] = param.data.clone()
+        return lora_state_dict
+
+    def save_lora_weights(self, save_path: Union[str, Path]):
+        """Save only the LoRA weights to disk."""
+        save_path = Path(save_path)
+        save_path.parent.mkdir(parents=True, exist_ok=True)
+
+        lora_state_dict = self.get_lora_state_dict()
+        torch.save(lora_state_dict, save_path)
+
+    def load_lora_weights(self, load_path: Union[str, Path]):
+        """Load LoRA weights from disk."""
+        load_path = Path(load_path)
+        if not load_path.exists():
+            raise ValueError(f"LoRA weights file {load_path} does not exist.")
+
+        # map_location ensures the LoRA weights are loaded on the same device as the model
+        lora_state_dict = torch.load(load_path, map_location=next(self.parameters()).device)
+
+        # Copy the LoRA weights into the model
+        for name, param in self.named_parameters():
+            if name in lora_state_dict:
+                param.data.copy_(lora_state_dict[name])
+
+    def set_lora_active(self, active: bool):
+        """Enable or disable LoRA adaptation in the model."""
+        for module in self.modules():
+            if isinstance(module, LoRALinear):
+                module.set_active(active)
+
+
+def apply_lora_to_linear_layer(
+    linear_layer: nn.Linear,
+    r: int = 8,
+    lora_alpha: float = 16,
+    lora_dropout: float = 0.0
+) -> LoRALinear:
+    """Replace a linear layer with a LoRA-adapted version."""
+    in_features, out_features = linear_layer.in_features, linear_layer.out_features
+
+    # Create a new LoRA linear layer initialized with the original weights and bias
+    lora_layer = LoRALinear(
+        in_features=in_features,
+        out_features=out_features,
+        weight=linear_layer.weight.data,  # Pass the actual weights
+        bias=linear_layer.bias.data if linear_layer.bias is not None else None,  # Pass the actual bias
+        r=r,
+        lora_alpha=lora_alpha,
+        lora_dropout=lora_dropout
+    )
+
+    return lora_layer
+
+
+def apply_lora_to_model(
+    model: nn.Module,
+    target_modules: List[str],
+    r: int = 8,
+    lora_alpha: float = 16,
+    lora_dropout: float = 0.0
+) -> nn.Module:
+    """
+    Apply LoRA to specific modules in a model.
+
+    Args:
+        model: The model to modify
+        target_modules: List of module names to apply LoRA to
+        r: LoRA rank
+        lora_alpha: LoRA alpha scaling factor
+        lora_dropout: Dropout probability for LoRA layers
+
+    Returns:
+        Modified model with LoRA layers
+    """
+    # Graft the LoRA mixin onto the model's class
+    model.__class__ = type(
+        f"{model.__class__.__name__}WithLoRA",
+        (model.__class__, LoRAModuleMixin),
+        {}
+    )
+
+    # Replace target modules with LoRA versions
+    # Materialize named_modules() into a list so we do not mutate the module tree while iterating over it
+    for name, module in list(model.named_modules()):
+        if any(target_name in name for target_name in target_modules):
+            parent_name, child_name = name.rsplit(".", 1) if "." in name else ("", name)
+            parent = model if parent_name == "" else _get_submodule(model, parent_name)
+
+            if isinstance(module, nn.Linear):
+                lora_layer = apply_lora_to_linear_layer(
+                    linear_layer=module,
+                    r=r,
+                    lora_alpha=lora_alpha,
+                    lora_dropout=lora_dropout
+                )
+                setattr(parent, child_name, lora_layer)
+
+    # Set only LoRA parameters as trainable
+    model.mark_only_lora_as_trainable()
+
+    return model
+
+
+def _get_submodule(model: nn.Module, target: str) -> nn.Module:
+    """Get a submodule from a model given its dotted path."""
+    if target == "":
+        return model
+
+    atoms = target.split(".")
+    module = model
+
+    for atom in atoms:
+        if not hasattr(module, atom):
+            raise AttributeError(f"Module {module} has no attribute {atom}")
+        module = getattr(module, atom)
+
+    return module
diff --git a/training/run_parler_tts_training.py b/training/run_parler_tts_training.py
index 1e368e4..925c893 100644
--- a/training/run_parler_tts_training.py
+++ b/training/run_parler_tts_training.py
@@ -54,6 +54,9 @@
     build_delay_pattern_mask,
 )
 
+# Import LoRA functionality
+from training.lora import apply_lora_to_model
+
 from training.utils import (
     get_last_checkpoint,
     rotate_checkpoints,
@@ -339,6 +342,22 @@ def main():
         attn_implementation={"decoder": model_args.attn_implementation, "text_encoder": "eager"},
     )
 
+    # Apply LoRA if enabled
+    if model_args.use_lora:
+        logger.info(f"Applying LoRA with rank {model_args.lora_rank} to target modules: {model_args.lora_target_modules}")
+        model = apply_lora_to_model(
+            model=model,
+            target_modules=model_args.lora_target_modules,
+            r=model_args.lora_rank,
+            lora_alpha=model_args.lora_alpha,
+            lora_dropout=model_args.lora_dropout
+        )
+        # Log trainable parameters after LoRA application
+        total_params = sum(p.numel() for p in model.parameters())
+        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+        logger.info(f"Total parameters: {total_params}")
+        logger.info(f"Trainable parameters: {trainable_params} ({trainable_params/total_params*100:.2f}%)")
+
     # enable gradient checkpointing if necessary
     if training_args.gradient_checkpointing:
         model.gradient_checkpointing_enable()
@@ -1088,6 +1107,15 @@ def generate_step(batch, accelerator):
                 if cur_step == total_train_steps:
                     # un-wrap student model for save
                     unwrapped_model = accelerator.unwrap_model(model)
+
+                    # If using LoRA, save the LoRA weights separately
+                    if model_args.use_lora and hasattr(unwrapped_model, 'save_lora_weights') and model_args.lora_weights_path is None:
+                        unwrapped_model.save_lora_weights(os.path.join(training_args.output_dir, "final_lora_weights.pt"))
+                        logger.info(f"Final LoRA weights saved to {os.path.join(training_args.output_dir, 'final_lora_weights.pt')}")
+                    elif model_args.use_lora and hasattr(unwrapped_model, 'save_lora_weights') and model_args.lora_weights_path is not None:
+                        unwrapped_model.save_lora_weights(model_args.lora_weights_path)
+                        logger.info(f"Final LoRA weights saved to {model_args.lora_weights_path}")
+
                     unwrapped_model.save_pretrained(training_args.output_dir)
 
                     if training_args.push_to_hub:
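
For reference, a minimal usage sketch of the helpers introduced in training/lora.py (not part of the patch). TinyEncoder and the file name lora_weights.pt are hypothetical stand-ins; only apply_lora_to_model and the LoRAModuleMixin methods come from the code above.

# Minimal sketch, assuming the module path training/lora.py added by this patch.
# TinyEncoder and "lora_weights.pt" are hypothetical; they only illustrate layer
# names that overlap with the default lora_target_modules.
import torch
import torch.nn as nn

from training.lora import apply_lora_to_model


class TinyEncoder(nn.Module):
    """Toy module whose layer names match some of the default LoRA targets."""

    def __init__(self):
        super().__init__()
        self.q_proj = nn.Linear(16, 16)
        self.fc1 = nn.Linear(16, 32)

    def forward(self, x):
        return self.fc1(torch.relu(self.q_proj(x)))


model = apply_lora_to_model(
    TinyEncoder(),
    target_modules=["q_proj", "fc1"],
    r=8,
    lora_alpha=16.0,
    lora_dropout=0.05,
)

# After the call, only the injected lora_A / lora_B matrices require gradients.
print([name for name, p in model.named_parameters() if p.requires_grad])

# The grafted mixin exposes adapter-only persistence and an on/off switch.
model.save_lora_weights("lora_weights.pt")
model.load_lora_weights("lora_weights.pt")
model.set_lora_active(False)  # fall back to the frozen base weights

When running through run_parler_tts_training.py, the same path would presumably be exercised by passing --use_lora together with the other lora_* flags, assuming the script's argument parser maps the new ModelArguments fields to command-line flags the same way it does for the existing ones.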