-
Couldn't load subscription status.
- Fork 2.1k
FEAT add GraLoRA #2851
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
FEAT add GraLoRA #2851
Changes from 8 commits
6dfa24e
bfa1ef7
9813b17
c1fe6c4
4f1444f
9431502
dec25f5
925ad72
3f69d8f
430e896
351877f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,32 @@ | ||
| # GraLoRA | ||
|
|
||
| [**Granular Low-Rank Adaptation (GraLoRA)**](https://huggingface.co/papers/2505.20355) is a PEFT method designed to enhance the **expressivity** of low-rank adaptation while improving **robustness to outlier** activations, based on insights from well-known issues in quantization. | ||
|
|
||
|  | ||
|
|
||
| Unlike standard LoRA, which applies a single low-rank adapter across the entire feature space, GraLoRA introduces a structured and fine-grained adaptation scheme. It divides the adaptation space into a grid of $𝑘^2$ smaller, independent adapter pairs, each responsible for a localized subset of the input and output dimensions. As a result, each adapter operates on a subspace that is $k$ times smaller in both dimensions than the original LoRA adapter. | ||
|
|
||
| This granular decomposition enables spatially localized and context-aware updates, effectively increasing representational capacity without additional parameters or computational cost. By isolating the influence of extreme activations within smaller subspaces, GraLoRA mitigates gradient distortion and preserves inter-channel balance during adaptation. | ||
|
|
||
| --- | ||
|
|
||
| The abstract from the paper is: | ||
|
|
||
| *Low-Rank Adaptation (LoRA) is a popular method for parameter-efficient fine- | ||
| tuning (PEFT) of generative models, valued for its simplicity and effectiveness. | ||
| Despite recent enhancements, LoRA still suffers from a fundamental limitation: | ||
| overfitting when the bottleneck is widened. It performs best at ranks 32–64, yet its | ||
| accuracy stagnates or declines at higher ranks, still falling short of full fine-tuning | ||
| (FFT) performance. We identify the root cause as LoRA’s structural bottleneck, | ||
| which introduces gradient entanglement to the unrelated input channels and distorts | ||
| gradient propagation. To address this, we introduce a novel structure, Granular | ||
| Low-Rank Adaptation (GraLoRA) that partitions weight matrices into sub-blocks, | ||
| each with its own low-rank adapter. With negligible computational or storage cost, | ||
| GraLoRA overcomes LoRA’s limitations, effectively increases the representational | ||
| capacity, and more closely approximates FFT behavior. Experiments on code | ||
| generation, commonsense reasoning, mathematical reasoning, general language | ||
| understanding, and image generation benchmarks show that GraLoRA consistently | ||
| outperforms LoRA and other baselines, achieving up to +8.5% absolute gain in | ||
| Pass@1 on HumanEval+. These improvements hold across model sizes and rank | ||
| settings, making GraLoRA a scalable and robust solution for PEFT.* | ||
|
|
||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,71 @@ | ||||||
| # GraLoRA: Granular Low-Rank Adaptation | ||||||
|
|
||||||
|  | ||||||
|
|
||||||
| ## Introduction | ||||||
| [**Granular Low-Rank Adaptation (GraLoRA)**](https://huggingface.co/papers/2505.20355) is a PEFT method designed to enhance the **expressivity** of low-rank adaptation while improving **robustness to outlier** activations, based on insights from well-known issues in quantization. | ||||||
|
|
||||||
| GraLoRA introduces a structured and fine-grained adaptation scheme. It divides the adaptation space into a grid of $𝑘^2$ smaller, independent adapter pairs, each responsible for a localized subset of the input and output dimensions. | ||||||
|
|
||||||
| ## Quick start | ||||||
|
|
||||||
| With respect to your standard PEFT training procedure with LoRA, simply swap your `LoraConfig` for a `GraloraConfig`. | ||||||
|
|
||||||
| ```python | ||||||
| import torch | ||||||
| from peft import GraloraConfig, get_peft_model | ||||||
| from transformers import AutoTokenizer, AutoModelForCausalLM, Trainer | ||||||
| from datasets import load_dataset | ||||||
|
|
||||||
| model = AutoModelForCausalLM.from_pretrained("huggyllama/llama-7b", device_map="auto") | ||||||
| tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b") | ||||||
| dataset = load_dataset("timdettmers/openassistant-guanaco", split="train") | ||||||
| gralora_config = GraloraConfig() | ||||||
| peft_model = get_peft_model(model, gralora_config) | ||||||
| trainer = transformers.Trainer( | ||||||
| model=peft_model, | ||||||
| train_dataset=dataset, | ||||||
| dataset_text_field="text", | ||||||
| max_seq_length=2048, | ||||||
| tokenizer=tokenizer, | ||||||
| ) | ||||||
|
Comment on lines
+25
to
+31
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The example didn't run for me because of unknown arguments. I know you copied an existing one, but either that one also didn't work, or it's because of some recent changes in #peft_model = get_peft_model(model, gralora_config) <= remove this
trainer = SFTTrainer(
model=model,
train_dataset=dataset,
processing_class=tokenizer,
peft_config=gralora_config,
args=SFTConfig(
max_length=2048,
dataset_text_field="text",
per_device_train_batch_size=2,
),
)
trainer.train()
trainer.model.save_pretrained("gralora-llama-3-8b")SFTTrainer and SFTConfig need to imported from trl. Alternatively, you could also rewrite it to use |
||||||
| trainer.train() | ||||||
| peft_model.save_pretrained("gralora-llama-3-8b") | ||||||
| ``` | ||||||
|
|
||||||
| Run the finetuning script simply by running: | ||||||
| ```python | ||||||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| python examples/gralora_finetuning/gralora_finetuning.py --base_model meta-llama/Meta-Llama-3-8B --data_path timdettmers/openassistant-guanaco | ||||||
| ``` | ||||||
|
|
||||||
| ## Use the model on 🤗 | ||||||
| You can load and use the model as any other 🤗 models. | ||||||
| ```python | ||||||
| import torch | ||||||
| from peft import PeftModel | ||||||
| from transformers import AutoModelForCausalLM | ||||||
|
|
||||||
| model = AutoModelForCausalLM.from_pretrained( | ||||||
| "meta-llama/Meta-Llama-3-8B", dtype=torch.bfloat16, device_map="auto" | ||||||
| ) | ||||||
| peft_model = PeftModel.from_pretrained(model, "gralora-llama-3-8b") | ||||||
| ``` | ||||||
|
|
||||||
| ## Additonal Notes | ||||||
| While `gralora_k` is set to 2 for default, you can increase this value to create more fine-grained adapters. `gralora_k` of 4 is recommended when the total rank (`r + hybrid_r`) is 64 or higher. | ||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
Comment on lines
+57
to
+59
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Delete 3 empty lines |
||||||
| ## Citation | ||||||
| ``` | ||||||
| @misc{jung2025graloragranularlowrankadaptation, | ||||||
| title={GraLoRA: Granular Low-Rank Adaptation for Parameter-Efficient Fine-Tuning}, | ||||||
| author={Yeonjoon Jung and Daehyun Ahn and Hyungjun Kim and Taesu Kim and Eunhyeok Park}, | ||||||
| year={2025}, | ||||||
| eprint={2505.20355}, | ||||||
| archivePrefix={arXiv}, | ||||||
| primaryClass={cs.LG}, | ||||||
| url={https://arxiv.org/abs/2505.20355}, | ||||||
| } | ||||||
| ``` | ||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,213 @@ | ||
| # This script is based on examples/dora_finetuning/dora_finetuning.py | ||
| import os | ||
|
|
||
| import torch | ||
| from datasets import load_dataset | ||
| from transformers import ( | ||
| AutoModelForCausalLM, | ||
| AutoTokenizer, | ||
| BitsAndBytesConfig, | ||
| DataCollatorForLanguageModeling, | ||
| Trainer, | ||
| TrainingArguments, | ||
| ) | ||
|
|
||
| from peft import GraloraConfig, get_peft_model, prepare_model_for_kbit_training | ||
|
|
||
|
|
||
| def train_model( | ||
| base_model: str, | ||
| data_path: str, | ||
| output_dir: str, | ||
| batch_size: int, | ||
| num_epochs: int, | ||
| learning_rate: float, | ||
| cutoff_len: int, | ||
| val_set_size: int, | ||
| quantize: bool, | ||
| eval_step: int, | ||
| save_step: int, | ||
| device: str, | ||
| gralora_r: int, | ||
| gralora_alpha: int, | ||
| gralora_dropout: float, | ||
| gralora_target_modules: str, | ||
| gralora_k: int, | ||
| hybrid_r: int, | ||
| hub_model_id: str, | ||
| push_to_hub: bool, | ||
| ): | ||
| os.environ["TOKENIZERS_PARALLELISM"] = "false" | ||
| hf_token = os.getenv("HF_TOKEN") | ||
|
|
||
| # Setup device | ||
| if device == "auto": | ||
| device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda" | ||
| else: | ||
| device = torch.device(device) | ||
| print(f"Using device: {device}") | ||
|
|
||
| # load tokenizer | ||
| tokenizer = AutoTokenizer.from_pretrained(base_model, token=hf_token) | ||
|
|
||
| # Quantized GraLoRA: IF YOU WANNA QUANTIZE THE MODEL | ||
| if quantize: | ||
| if (torch.cuda.is_available() and torch.cuda.is_bf16_supported()) or torch.xpu.is_available(): | ||
| bnb_4bit_compute_dtype = torch.bfloat16 | ||
| else: | ||
| bnb_4bit_compute_dtype = torch.float16 | ||
| model = AutoModelForCausalLM.from_pretrained( | ||
| base_model, | ||
| token=hf_token, | ||
| quantization_config=BitsAndBytesConfig( | ||
| load_in_4bit=True, | ||
| bnb_4bit_compute_dtype=bnb_4bit_compute_dtype, | ||
| bnb_4bit_use_double_quant=True, | ||
| bnb_4bit_quant_type="nf4", | ||
| ), | ||
| ) | ||
| # setup for quantized training | ||
| model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True) | ||
| else: | ||
| model = AutoModelForCausalLM.from_pretrained(base_model, token=hf_token) | ||
| # GraLoRA config for the PEFT model | ||
| gralora_config = GraloraConfig( | ||
| r=gralora_r, # Rank of matrix | ||
| gralora_alpha=gralora_alpha, | ||
| target_modules=( | ||
| gralora_target_modules.split(",") | ||
| if gralora_target_modules | ||
| else ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] | ||
| ), | ||
| gralora_dropout=gralora_dropout, | ||
| gralora_k=gralora_k, | ||
| hybrid_r=hybrid_r, | ||
| bias="none", | ||
| ) | ||
|
|
||
| # get the peft model with GraLoRA config | ||
| model = get_peft_model(model, gralora_config) | ||
|
|
||
| model.to(device) # MODEL TO GPU/CUDA | ||
| tokenizer.pad_token = tokenizer.eos_token | ||
|
|
||
| # Load the dataset | ||
| dataset = load_dataset(data_path) | ||
|
|
||
| def tokenize_function(examples): | ||
| inputs = tokenizer(examples["text"], padding="max_length", truncation=True, max_length=cutoff_len) | ||
| inputs["labels"] = inputs["input_ids"].copy() # setting labels for a language modeling task | ||
| return inputs | ||
|
|
||
| # Tokenize the dataset and prepare for training | ||
| tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names) | ||
|
|
||
| # Data collator to dynamically pad the batched examples | ||
| data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False) | ||
|
|
||
| # Define training arguments | ||
| training_args = TrainingArguments( | ||
| output_dir=output_dir, | ||
| num_train_epochs=num_epochs, | ||
| per_device_train_batch_size=batch_size, | ||
| per_device_eval_batch_size=batch_size, | ||
| warmup_steps=100, | ||
| weight_decay=0.01, | ||
| logging_dir="./logs", | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This argument errors for me, does it work for you? |
||
| logging_steps=eval_step, | ||
| save_steps=save_step, | ||
| save_total_limit=2, | ||
| push_to_hub=push_to_hub, | ||
| hub_model_id=hub_model_id, | ||
| gradient_accumulation_steps=16, | ||
| fp16=True, | ||
| learning_rate=learning_rate, | ||
| hub_token=hf_token, | ||
| ) | ||
|
|
||
| # Clear device cache to free memory | ||
| if torch.cuda.is_available(): | ||
| torch.cuda.empty_cache() | ||
| elif torch.xpu.is_available(): | ||
| torch.xpu.empty_cache() | ||
|
|
||
| # Initialize the Trainer | ||
| trainer = Trainer( | ||
| model=model, | ||
| args=training_args, | ||
| train_dataset=tokenized_datasets["train"], | ||
| eval_dataset=tokenized_datasets["test"], | ||
| data_collator=data_collator, | ||
| ) | ||
|
|
||
| # Start model training | ||
| trainer.train() | ||
|
|
||
| # Save and push the trained model and tokenizer | ||
| if push_to_hub: | ||
| # Push the main model to the hub | ||
| trainer.push_to_hub(commit_message="Fine-tuned model") | ||
|
|
||
| # Save the model and tokenizer locally | ||
| model.save_pretrained(output_dir) | ||
| tokenizer.save_pretrained(output_dir) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| import argparse | ||
|
|
||
| parser = argparse.ArgumentParser(description="Fine-tune LLaMA with GraLoRA and PEFT") | ||
| parser.add_argument("--base_model", type=str, default="huggyllama/llama-7b", help="Base model path or name") | ||
| parser.add_argument( | ||
| "--data_path", type=str, default="timdettmers/openassistant-guanaco", help="Dataset path or name" | ||
| ) | ||
| parser.add_argument( | ||
| "--output_dir", type=str, default="path/to/output", help="Output directory for the fine-tuned model" | ||
| ) | ||
| parser.add_argument("--batch_size", type=int, default=1, help="Batch size") | ||
| parser.add_argument("--num_epochs", type=int, default=1, help="Number of training epochs") | ||
| parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate") | ||
| parser.add_argument("--cutoff_len", type=int, default=512, help="Cutoff length for tokenization") | ||
| parser.add_argument("--val_set_size", type=int, default=500, help="Validation set size") | ||
| parser.add_argument("--quantize", action="store_true", help="Use quantization") | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Quantizing wouldn't work, right? |
||
| parser.add_argument("--eval_step", type=int, default=10, help="Evaluation step interval") | ||
| parser.add_argument("--save_step", type=int, default=100, help="Save step interval") | ||
| parser.add_argument("--device", type=str, default="auto", help="Device to use for training") | ||
| parser.add_argument("--gralora_r", type=int, default=8, help="LoRA rank") | ||
| parser.add_argument("--gralora_alpha", type=int, default=16, help="LoRA alpha") | ||
| parser.add_argument("--gralora_dropout", type=float, default=0.05, help="LoRA dropout rate") | ||
| parser.add_argument( | ||
| "--gralora_target_modules", type=str, default=None, help="Comma-separated list of target modules for LoRA" | ||
| ) | ||
| parser.add_argument("--gralora_k", type=int, default=2, help="GraLoRA k") | ||
| parser.add_argument("--hybrid_r", type=int, default=0, help="Hybrid rank") | ||
| parser.add_argument( | ||
| "--hub_model_id", | ||
| type=str, | ||
| default="path/to/repo", | ||
| help="Repository name to push the model on the Hugging Face Hub", | ||
| ) | ||
| parser.add_argument("--push_to_hub", action="store_true", help="Whether to push the model to Hugging Face Hub") | ||
| args = parser.parse_args() | ||
| train_model( | ||
| base_model=args.base_model, | ||
| data_path=args.data_path, | ||
| output_dir=args.output_dir, | ||
| batch_size=args.batch_size, | ||
| num_epochs=args.num_epochs, | ||
| learning_rate=args.learning_rate, | ||
| cutoff_len=args.cutoff_len, | ||
| val_set_size=args.val_set_size, | ||
| quantize=args.quantize, | ||
| eval_step=args.eval_step, | ||
| save_step=args.save_step, | ||
| device=args.device, | ||
| gralora_r=args.gralora_r, | ||
| gralora_alpha=args.gralora_alpha, | ||
| gralora_dropout=args.gralora_dropout, | ||
| gralora_target_modules=args.gralora_target_modules, | ||
| gralora_k=args.gralora_k, | ||
| hybrid_r=args.hybrid_r, | ||
| hub_model_id=args.hub_model_id, | ||
| push_to_hub=args.push_to_hub, | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,24 @@ | ||
| # Copyright 2025-present the HuggingFace Inc. team. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| from peft.utils import register_peft_method | ||
|
|
||
| from .config import GraloraConfig | ||
| from .layer import GraloraLayer | ||
| from .model import GraloraModel | ||
|
|
||
|
|
||
| __all__ = ["GraloraConfig", "GraloraLayer", "GraloraModel"] | ||
|
|
||
| register_peft_method(name="gralora", config_cls=GraloraConfig, model_cls=GraloraModel) |
Uh oh!
There was an error while loading. Please reload this page.