From 93ad8079316ecd24d11d22968952a6227da66607 Mon Sep 17 00:00:00 2001
From: aniruddh10124 <88266561+aniruddh10124@users.noreply.github.com>
Date: Fri, 16 May 2025 21:43:26 +0530
Subject: [PATCH 1/4] Create lora.py

This is the base file: it implements the LoRA linear layer and useful helper
methods for working with a model to which LoRA has been applied. It also
implements a function that can convert any model into a LoRA model, given the
names of the linear layers to adapt.
---
 training/lora.py | 213 +++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 213 insertions(+)
 create mode 100644 training/lora.py

diff --git a/training/lora.py b/training/lora.py
new file mode 100644
index 0000000..88c9d5c
--- /dev/null
+++ b/training/lora.py
@@ -0,0 +1,213 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple, Union
+
+class LoRALinear(nn.Module):
+    """
+    LoRA-adapted Linear layer using only PyTorch primitives.
+    """
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        weight: Optional[torch.Tensor] = None,
+        bias: Optional[torch.Tensor] = None,
+        r: int = 8,
+        lora_alpha: float = 16,
+        lora_dropout: float = 0.0
+    ):
+        super().__init__()
+
+        # Store original layer parameters
+        if weight is None:
+            # Initialize with zeros to make it obvious if weights aren't set properly
+            self.weight = nn.Parameter(torch.zeros(out_features, in_features))
+        else:
+            self.weight = nn.Parameter(weight.clone())
+
+        if bias is not None:
+            self.bias = nn.Parameter(bias.clone())
+        else:
+            self.bias = None
+
+        # LoRA specific parameters
+        self.r = r
+        self.lora_alpha = lora_alpha
+        self.scaling = lora_alpha / r
+
+        # LoRA low-rank matrices
+        # We use kaiming_uniform initialization per original LoRA paper
+        self.lora_A = nn.Parameter(torch.empty(r, in_features))
+        self.lora_B = nn.Parameter(torch.empty(out_features, r))
+        self.reset_lora_parameters()
+
+        # Optional dropout
+        self.lora_dropout = nn.Dropout(p=lora_dropout) if lora_dropout > 0 else nn.Identity()
+
+        # For tracking active status
+        self.active = True
+
+    def reset_lora_parameters(self):
+        """Reset LoRA parameters using kaiming uniform initialization."""
+        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
+        nn.init.zeros_(self.lora_B)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # Main linear operation
+        result = F.linear(x, self.weight, self.bias)
+
+        # Add the LoRA contribution when active
+        if self.active:
+            # Apply dropout to the input
+            lora_x = self.lora_dropout(x)
+
+            # Low-rank adaptation contribution: B·(A·x)·scaling
+            lora_result = (lora_x @ self.lora_A.T) @ self.lora_B.T
+            result += lora_result * self.scaling
+
+        return result
+
+    def set_active(self, active: bool):
+        """Set whether LoRA adaptation is active."""
+        self.active = active
+
+
+class LoRAModuleMixin:
+    """
+    Mixin to add LoRA functionality to a model.
+    """
+    def mark_only_lora_as_trainable(self):
+        """Freeze all parameters except LoRA parameters."""
+        for param in self.parameters():
+            param.requires_grad = False
+
+        for name, param in self.named_parameters():
+            if "lora_A" in name or "lora_B" in name:
+                param.requires_grad = True
+
+    def get_lora_state_dict(self) -> Dict[str, torch.Tensor]:
+        """Get state dict containing only LoRA parameters."""
+        lora_state_dict = {}
+        for name, param in self.named_parameters():
+            if "lora_A" in name or "lora_B" in name:
+                lora_state_dict[name] = param.data.clone()
+        return lora_state_dict
+
+    def save_lora_weights(self, save_path: Union[str, Path]):
+        """Save only LoRA weights to disk."""
+        save_path = Path(save_path)
+        save_path.parent.mkdir(parents=True, exist_ok=True)
+
+        lora_state_dict = self.get_lora_state_dict()
+        torch.save(lora_state_dict, save_path)
+
+    def load_lora_weights(self, load_path: Union[str, Path]):
+        """Load LoRA weights from disk."""
+        load_path = Path(load_path)
+        if not load_path.exists():
+            raise ValueError(f"LoRA weights file {load_path} does not exist.")
+
+        # map_location ensures that the LoRA weights are on the same device as the model
+        lora_state_dict = torch.load(load_path, map_location=next(self.parameters()).device)
+
+        # Load LoRA weights into model
+        for name, param in self.named_parameters():
+            if name in lora_state_dict:
+                param.data.copy_(lora_state_dict[name])
+
+    def set_lora_active(self, active: bool):
+        """Enable or disable LoRA adaptation in the model."""
+        for module in self.modules():
+            if isinstance(module, LoRALinear):
+                module.set_active(active)
+
+
+def apply_lora_to_linear_layer(
+    linear_layer: nn.Linear,
+    r: int = 8,
+    lora_alpha: float = 16,
+    lora_dropout: float = 0.0
+) -> LoRALinear:
+    """Replace a linear layer with a LoRA-adapted version."""
+    in_features, out_features = linear_layer.in_features, linear_layer.out_features
+
+    # Create new LoRA linear layer with the original weights and biases
+    lora_layer = LoRALinear(
+        in_features=in_features,
+        out_features=out_features,
+        weight=linear_layer.weight.data,  # Pass the actual weights
+        bias=linear_layer.bias.data if linear_layer.bias is not None else None,  # Pass the actual bias
+        r=r,
+        lora_alpha=lora_alpha,
+        lora_dropout=lora_dropout
+    )
+
+    return lora_layer
+
+
+def apply_lora_to_model(
+    model: nn.Module,
+    target_modules: List[str],
+    r: int = 8,
+    lora_alpha: float = 16,
+    lora_dropout: float = 0.0
+) -> nn.Module:
+    """
+    Apply LoRA to specific modules in a model.
+
+    Args:
+        model: The model to modify
+        target_modules: List of module names to apply LoRA to
+        r: LoRA rank
+        lora_alpha: LoRA alpha scaling factor
+        lora_dropout: Dropout probability for LoRA layers
+
+    Returns:
+        Modified model with LoRA layers
+    """
+    # Apply LoRA mixin to the model
+    model.__class__ = type(
+        f"{model.__class__.__name__}WithLoRA",
+        (model.__class__, LoRAModuleMixin),
+        {}
+    )
+
+    # Replace target modules with LoRA versions
+    # materialize named_modules() into a list so modules can be replaced safely while iterating
+    for name, module in list(model.named_modules()):
+        if any(target_name in name for target_name in target_modules):
+            parent_name, child_name = name.rsplit(".", 1) if "." in name else ("", name)
+            parent = model if parent_name == "" else _get_submodule(model, parent_name)
+
+            if isinstance(module, nn.Linear):
+                lora_layer = apply_lora_to_linear_layer(
+                    linear_layer=module,
+                    r=r,
+                    lora_alpha=lora_alpha,
+                    lora_dropout=lora_dropout
+                )
+                setattr(parent, child_name, lora_layer)
+
+    # Set only LoRA parameters as trainable
+    model.mark_only_lora_as_trainable()
+
+    return model
+
+
+def _get_submodule(model: nn.Module, target: str) -> nn.Module:
+    """Get a submodule from a model given its path."""
+    if target == "":
+        return model
+
+    atoms = target.split(".")
+    module = model
+
+    for atom in atoms:
+        if not hasattr(module, atom):
+            raise AttributeError(f"Module {module} has no attribute {atom}")
+        module = getattr(module, atom)
+
+    return module

From 3dbdfd204c47db6d24b8858ed6322897399e4d7e Mon Sep 17 00:00:00 2001
From: aniruddh10124 <88266561+aniruddh10124@users.noreply.github.com>
Date: Fri, 16 May 2025 21:45:02 +0530
Subject: [PATCH 2/4] Update arguments.py

Added a use_lora flag for applying LoRA, along with arguments to configure the
LoRA parameters: rank, alpha, dropout probability, target module names (the
Linear modules that will be replaced with LoRA Linear layers) and the LoRA
weights path.
---
 training/arguments.py | 30 +++++++++++++++++++++++++++---
 1 file changed, 27 insertions(+), 3 deletions(-)

diff --git a/training/arguments.py b/training/arguments.py
index 7c05821..a41ba61 100644
--- a/training/arguments.py
+++ b/training/arguments.py
@@ -1,7 +1,7 @@
 from dataclasses import dataclass, field
-from typing import Optional, List
+from typing import Optional, List, Union
 
-from transformers import Seq2SeqTrainingArguments
+from transformers import Seq2SeqTrainingArguments, TrainingArguments
 
 
 @dataclass
@@ -94,6 +94,30 @@ class ModelArguments:
             "help": "Prompt tokenizer padding side. Defaults to `left`. If the prompt is pre-pended to the codebooks hidden states, it should be padded on the left."
         },
     )
+    use_lora: bool = field(
+        default=False,
+        metadata={"help": "Whether to use LoRA for parameter-efficient fine-tuning"}
+    )
+    lora_rank: int = field(
+        default=8,
+        metadata={"help": "Rank of LoRA adaptation matrices"}
+    )
+    lora_alpha: float = field(
+        default=16.0,
+        metadata={"help": "LoRA alpha parameter (scaling factor)"}
+    )
+    lora_dropout: float = field(
+        default=0.05,
+        metadata={"help": "Dropout probability for LoRA layers"}
+    )
+    lora_target_modules: List[str] = field(
+        default_factory=lambda: ['q_proj', 'k_proj', 'v_proj', 'out_proj', 'fc1', 'fc2'],
+        metadata={"help": "Names of modules to apply LoRA to"}
+    )
+    lora_weights_path: Optional[str] = field(
+        default=None,
+        metadata={"help": "Path to pretrained LoRA weights to load"}
+    )
 
 
 @dataclass
@@ -372,4 +396,4 @@ class ParlerTTSTrainingArguments(Seq2SeqTrainingArguments):
     codebook_weights: Optional[List[float]] = field(
         default=None,
         metadata={"help": "Weights applied to each codebook."},
-    )
\ No newline at end of file
+    )

From 5e87dfc9e0adbedff0830b23360e2846e4850e3b Mon Sep 17 00:00:00 2001
From: aniruddh10124 <88266561+aniruddh10124@users.noreply.github.com>
Date: Fri, 16 May 2025 21:47:00 +0530
Subject: [PATCH 3/4] Update run_parler_tts_training.py

Changed the training script to support LoRA fine-tuning.
Three main areas of change:
* Imported apply_lora_to_model
* Applied the above function to the model if the use_lora flag is True
* Saved the LoRA weights at the point where the other model weights are saved
---
 training/run_parler_tts_training.py | 25 +++++++++++++++++++++++++
 1 file changed, 25 insertions(+)

diff --git a/training/run_parler_tts_training.py b/training/run_parler_tts_training.py
index 1e368e4..12e2667 100644
--- a/training/run_parler_tts_training.py
+++ b/training/run_parler_tts_training.py
@@ -54,6 +54,9 @@
     build_delay_pattern_mask,
 )
 
+# Import LoRA functionality
+from training.lora import apply_lora_to_model
+
 from training.utils import (
     get_last_checkpoint,
     rotate_checkpoints,
@@ -339,6 +342,22 @@ def main():
         attn_implementation={"decoder": model_args.attn_implementation, "text_encoder": "eager"},
     )
 
+    # Apply LoRA if enabled
+    if model_args.use_lora:
+        logger.info(f"Applying LoRA with rank {model_args.lora_rank} to target modules: {model_args.lora_target_modules}")
+        model = apply_lora_to_model(
+            model=model,
+            target_modules=model_args.lora_target_modules,
+            r=model_args.lora_rank,
+            lora_alpha=model_args.lora_alpha,
+            lora_dropout=model_args.lora_dropout
+        )
+        # Log trainable parameters after LoRA application
+        total_params = sum(p.numel() for p in model.parameters())
+        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+        logger.info(f"Total parameters: {total_params}")
+        logger.info(f"Trainable parameters: {trainable_params} ({trainable_params/total_params*100:.2f}%)")
+
     # enable gradient checkpointing if necessary
     if training_args.gradient_checkpointing:
         model.gradient_checkpointing_enable()
@@ -1088,6 +1107,12 @@ def generate_step(batch, accelerator):
                     if cur_step == total_train_steps:
                         # un-wrap student model for save
                         unwrapped_model = accelerator.unwrap_model(model)
+
+                        # If using LoRA, save LoRA weights separately
+                        if model_args.use_lora and hasattr(unwrapped_model, 'save_lora_weights'):
+                            unwrapped_model.save_lora_weights(os.path.join(training_args.output_dir, "final_lora_weights.pt"))
+                            logger.info(f"Final LoRA weights saved to {os.path.join(training_args.output_dir, 'final_lora_weights.pt')}")
+
                         unwrapped_model.save_pretrained(training_args.output_dir)
 
                         if training_args.push_to_hub:

From 48dce6914e6512667330c106c84e24b6fc7b673a Mon Sep 17 00:00:00 2001
From: aniruddh10124 <88266561+aniruddh10124@users.noreply.github.com>
Date: Fri, 16 May 2025 22:59:50 +0530
Subject: [PATCH 4/4] Update run_parler_tts_training.py

Edited the training script so that the final LoRA weights are saved to
lora_weights_path when it is provided, and to the output directory otherwise.
---
 training/run_parler_tts_training.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/training/run_parler_tts_training.py b/training/run_parler_tts_training.py
index 12e2667..925c893 100644
--- a/training/run_parler_tts_training.py
+++ b/training/run_parler_tts_training.py
@@ -1109,9 +1109,12 @@ def generate_step(batch, accelerator):
                         unwrapped_model = accelerator.unwrap_model(model)
 
                         # If using LoRA, save LoRA weights separately
-                        if model_args.use_lora and hasattr(unwrapped_model, 'save_lora_weights'):
+                        if model_args.use_lora and hasattr(unwrapped_model, 'save_lora_weights') and model_args.lora_weights_path is None:
                             unwrapped_model.save_lora_weights(os.path.join(training_args.output_dir, "final_lora_weights.pt"))
                             logger.info(f"Final LoRA weights saved to {os.path.join(training_args.output_dir, 'final_lora_weights.pt')}")
+                        elif model_args.use_lora and hasattr(unwrapped_model, 'save_lora_weights') and model_args.lora_weights_path is not None:
+                            unwrapped_model.save_lora_weights(model_args.lora_weights_path)
+                            logger.info(f"Final LoRA weights saved to {model_args.lora_weights_path}")
 
                         unwrapped_model.save_pretrained(training_args.output_dir)
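
For reference, below is a minimal usage sketch (not part of the patch series) showing how the pieces introduced above fit together. It assumes training/lora.py from PATCH 1 is importable as training.lora; the TinyAttention module and its layer names are made-up placeholders, while apply_lora_to_model, save_lora_weights, load_lora_weights and set_lora_active are the functions defined in the patch.

# Illustrative sketch only, not part of the patch. Assumes training/lora.py is importable.
import torch
import torch.nn as nn

from training.lora import apply_lora_to_model


class TinyAttention(nn.Module):
    """Made-up stand-in for a model containing q/k/v projection layers."""
    def __init__(self, dim: int = 32):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.q_proj(x) + self.k_proj(x) + self.v_proj(x)


model = TinyAttention()

# Swap the named nn.Linear layers for LoRALinear and freeze everything else.
model = apply_lora_to_model(model, target_modules=["q_proj", "k_proj", "v_proj"], r=4, lora_alpha=8.0)

# Only lora_A / lora_B parameters remain trainable after the conversion.
print([name for name, p in model.named_parameters() if p.requires_grad])

out = model(torch.randn(2, 32))  # forward pass works as before, now including the LoRA contribution

# After training, persist only the adapter weights and reload them later.
model.save_lora_weights("final_lora_weights.pt")
model.load_lora_weights("final_lora_weights.pt")

# Temporarily disable the adaptation to compare against the base model's output.
model.set_lora_active(False)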