|
1 | | -# Orthogonal Subspace Learning with Adaptive OSF |
| 1 | +# Orthogonal Subspace Fine-tuning (OSF) - Continual Learning Example |
2 | 2 |
|
3 | | -## TODO: Runnable Example Needed |
| 3 | +This example demonstrates OSF's ability to learn multiple tasks sequentially while preventing catastrophic forgetting, a key challenge in continual learning. |
4 | 4 |
|
5 | | -This folder is a placeholder for a comprehensive OSF example. As suggested in the review feedback: |
| 5 | +## Introduction |
6 | 6 |
|
7 | | -> "If you can, provide a runnable example in this folder instead, you can take a look at the EVA example for inspiration. A runnable example can be a good place to showcase the different features. Jupyter notebooks are fine as well." |
| 7 | +**Orthogonal Subspace Fine-tuning (OSF)** is a parameter-efficient fine-tuning method designed specifically for continual learning scenarios. Unlike traditional fine-tuning which suffers from catastrophic forgetting when learning new tasks, OSF constrains parameter updates to be orthogonal to previously important directions, effectively preserving knowledge from earlier tasks. |
8 | 8 |
|
9 | | -### Planned Example Features: |
10 | | -- Complete continual learning scenario with multiple tasks |
11 | | -- Demonstration of OSF's catastrophic forgetting prevention |
12 | | -- Configuration examples (target_modules, effective_rank, rank_pattern) |
13 | | -- Performance comparison with baseline methods |
14 | | -- Memory usage analysis |
| 9 | +### Key Features |
15 | 10 |
|
16 | | -### Current Basic Usage: |
17 | | -For basic usage examples and API documentation, see the [OSF documentation](../../docs/source/package_reference/osf.md). |
| 11 | +- **Prevents Catastrophic Forgetting**: Maintains performance on previous tasks while learning new ones |
| 12 | +- **Full Model Capacity**: Unlike LoRA-based methods, OSF allows full-rank updates within the trainable subspace |
| 13 | +- **Progressive Budget Allocation**: Gradually allocates more capacity to preserve previous knowledge |
| 14 | +- **No Additional Parameters**: Modifies weights in-place without adding extra parameters per task |
| 15 | + |
| 16 | +## Quick Start |
| 17 | + |
| 18 | +### Installation |
| 19 | + |
| 20 | +```bash |
| 21 | +pip install -e ".[dev]" |
| 22 | +``` |
| 23 | + |
| 24 | +### Basic Usage |
| 25 | + |
| 26 | +Run the continual learning example with OSF: |
| 27 | + |
| 28 | +```bash |
| 29 | +python osf_continual_learning.py \ |
| 30 | + --model_name meta-llama/Llama-3.1-8B-Instruct \ |
| 31 | + --num_train 1000 \ |
| 32 | + --num_eval 200 \ |
| 33 | + --num_epochs 2 \ |
| 34 | + --output_dir ./outputs |
| 35 | +``` |
| 36 | + |
| 37 | +To also run the full fine-tuning baseline for comparison:
| 38 | + |
| 39 | +```bash |
| 40 | +python osf_continual_learning.py \ |
| 41 | + --model_name meta-llama/Llama-3.1-8B-Instruct \ |
| 42 | + --run_baseline \ |
| 43 | + --output_dir ./outputs |
| 44 | +``` |
| 45 | + |
| 46 | +## Continual Learning Scenario |
| 47 | + |
| 48 | +This example trains a model on three different tasks sequentially: |
| 49 | + |
| 50 | +1. **ScienceQA** - Science question answering across natural, language, and social sciences |
| 51 | +2. **NumGLUE** - Mathematical reasoning and numerical understanding |
| 52 | +3. **FOMC** - Financial sentiment classification (Dovish/Hawkish/Neutral) |
| 53 | + |
| 54 | +### Progressive Capacity Allocation |
| 55 | + |
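| 56 | +This example uses a progressive budget allocation strategy: each successive task gets less trainable capacity, and a larger share of each weight matrix is frozen to preserve knowledge from earlier tasks. Here `effective_rank` is the fraction of singular directions that is preserved (frozen):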
| 57 | + |
| 58 | +| Task | Effective Rank | Preserved | Trainable | Description | |
| 59 | +|------|----------------|-----------|-----------|-------------| |
| 60 | +| Task 1 (ScienceQA) | 0.3 | 30% | 70% | Maximum capacity for first task | |
| 61 | +| Task 2 (NumGLUE) | 0.5 | 50% | 50% | Balanced capacity allocation | |
| 62 | +| Task 3 (FOMC) | 0.7 | 70% | 30% | Minimal capacity, maximum preservation | |
| 63 | + |
| 64 | +This allocation ensures: |
| 65 | +- Early tasks get sufficient capacity to learn effectively |
| 66 | +- Later tasks can still learn new patterns |
| 67 | +- Previous knowledge is progressively protected from interference |
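
As a rough illustration of the table above, the following sketch converts a fractional `effective_rank` into counts of frozen and trainable singular directions for one weight matrix. The helper name `split_budget`, the 4096 x 4096 shape, and the round-down rule are assumptions made for illustration; the rounding PEFT applies internally may differ.

```python
def split_budget(shape, effective_rank):
    """Return (frozen, trainable) singular-direction counts for one matrix.

    Assumption: a fractional effective_rank is taken relative to
    min(rows, cols) and rounded down.
    """
    max_rank = min(shape)
    frozen = int(max_rank * effective_rank)
    return frozen, max_rank - frozen


# Schedule taken from the table: task name -> effective_rank
schedule = {"ScienceQA": 0.3, "NumGLUE": 0.5, "FOMC": 0.7}
budgets = {task: split_budget((4096, 4096), frac) for task, frac in schedule.items()}

for task, (frozen, trainable) in budgets.items():
    print(f"{task}: frozen={frozen}, trainable={trainable}")
```

The trainable count shrinks from task to task, which is the "decreasing capacity" the table describes.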
| 68 | + |
| 69 | +## How OSF Works |
| 70 | + |
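| 71 | +OSF uses SVD to split each weight matrix into a preserved component spanned by the top singular directions (the "high-rank" subspace) and a trainable component spanned by the remaining directions (the "low-rank" subspace):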
| 72 | + |
| 73 | +``` |
| 74 | +W = U_high @ S_high @ V_high^T + U_low @ S_low @ V_low^T
| 75 | +      └─────────┬─────────┘          └──────┬──────┘
| 76 | +              frozen                    trainable
| 77 | +        (previous tasks)             (current task)
| 78 | +``` |
| 79 | + |
| 80 | +During training: |
| 81 | +1. **Initialization**: Perform SVD on each weight matrix |
| 82 | +2. **Partitioning**: Split singular values based on `effective_rank` |
| 83 | +3. **Freezing**: Freeze top-k singular directions (high-rank subspace) |
| 84 | +4. **Training**: Update remaining directions (low-rank subspace) |
| 85 | +5. **Gradient Projection**: Ensure updates are orthogonal to frozen subspace |
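
The training steps above can be sketched on a toy matrix with NumPy. This is a minimal illustration, not PEFT's implementation: the matrix size, the random stand-in gradient `G`, and the choice to project one dense gradient (rather than the gradients of the individual trainable factors) are all assumptions.

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.standard_normal((8, 6))      # toy weight matrix
effective_rank = 0.5                 # fraction of singular directions to freeze

# Steps 1-2: SVD, then partition the singular directions
U, S, Vt = np.linalg.svd(W, full_matrices=False)
k = int(len(S) * effective_rank)
U_high, S_high, Vt_high = U[:, :k], S[:k], Vt[:k]
U_low, S_low, Vt_low = U[:, k:], S[k:], Vt[k:]

# Steps 3-4: frozen and trainable components; together they reproduce W
W_frozen = U_high @ np.diag(S_high) @ Vt_high
W_train = U_low @ np.diag(S_low) @ Vt_low
recon_error = np.abs(W - (W_frozen + W_train)).max()

# Step 5: remove the parts of a gradient that lie in the frozen
# column space (U_high) or frozen row space (Vt_high)
G = rng.standard_normal(W.shape)     # stand-in for a real gradient
P_left = np.eye(W.shape[0]) - U_high @ U_high.T
P_right = np.eye(W.shape[1]) - Vt_high.T @ Vt_high
G_proj = P_left @ G @ P_right

# The projected gradient has no overlap with the frozen directions
left_overlap = np.abs(U_high.T @ G_proj).max()
right_overlap = np.abs(G_proj @ Vt_high.T).max()
print(k, recon_error, left_overlap, right_overlap)
```

Because `G_proj` is orthogonal to the frozen directions on both sides, a step along it leaves `W_frozen` untouched.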
| 86 | + |
| 87 | +Between tasks: |
| 88 | +1. **Unload**: Merge OSF components back into base model |
| 89 | +2. **Re-initialize**: Perform fresh SVD with increased `effective_rank` |
| 90 | +3. **Continue**: Train on next task with larger frozen subspace |
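
The between-task cycle can be simulated on a single dense matrix. In the actual script these steps operate on a PEFT model; in this sketch the helper `split_by_rank`, the 16 x 16 matrix, and the random "training" updates are hypothetical stand-ins used only to show that the frozen component survives each task while the frozen share grows.

```python
import numpy as np


def split_by_rank(W, effective_rank):
    """Fresh SVD of W; return the frozen bases and the frozen component."""
    U, S, Vt = np.linalg.svd(W, full_matrices=False)
    k = int(len(S) * effective_rank)
    frozen = (U[:, :k] * S[:k]) @ Vt[:k]
    return U[:, :k], Vt[:k], frozen


rng = np.random.default_rng(0)
W = rng.standard_normal((16, 16))    # stands in for one merged weight matrix
frozen_counts, drift = [], []

for effective_rank in (0.3, 0.5, 0.7):
    # Re-initialize: fresh SVD of the current weights, larger frozen share
    U_high, Vt_high, W_frozen = split_by_rank(W, effective_rank)
    frozen_counts.append(U_high.shape[1])

    # Continue: a few projected updates stand in for training on one task
    W_new = W.copy()
    for _ in range(5):
        G = rng.standard_normal(W.shape)
        G -= U_high @ (U_high.T @ G)       # drop frozen column-space part
        G -= (G @ Vt_high.T) @ Vt_high     # drop frozen row-space part
        W_new -= 0.1 * G

    # How far did the frozen component move during this task?
    drift.append(np.abs(U_high @ (U_high.T @ W_new) - W_frozen).max())

    # Unload: the merged dense matrix becomes the base for the next task
    W = W_new

print(frozen_counts, max(drift))
```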
| 91 | + |
| 92 | +## Command Line Arguments |
| 93 | + |
| 94 | +``` |
| 95 | +--model_name Model to use (default: meta-llama/Llama-3.1-8B-Instruct) |
| 96 | +--num_train Number of training samples per task (default: 1000) |
| 97 | +--num_eval Number of evaluation samples per task (default: 200) |
| 98 | +--output_dir Directory for outputs (default: ./osf_continual_learning_outputs) |
| 99 | +--num_epochs Training epochs per task (default: 2) |
| 100 | +--learning_rate Learning rate (default: 5e-6) |
| 101 | +--batch_size Batch size per device (default: 32) |
| 102 | +--gradient_accumulation_steps Gradient accumulation (default: 1) |
| 103 | +--max_length Maximum sequence length (default: 512) |
| 104 | +--seed Random seed (default: 42) |
| 105 | +--run_baseline Also run full fine-tuning baseline for comparison |
| 106 | +``` |
| 107 | + |
| 108 | +## Expected Results |
| 109 | + |
| 110 | +### OSF Performance |
| 111 | + |
| 112 | +When using OSF (with 2 epochs per task), you should observe: |
| 113 | +- **Reduced catastrophic forgetting**: Performance on earlier tasks degrades less than with full fine-tuning
| 114 | +- **Continued learning**: Model successfully learns each new task |
| 115 | +- **Better retention**: OSF maintains higher average accuracy across all tasks |
| 116 | + |
| 117 | +### Full Fine-tuning Baseline |
| 118 | + |
| 119 | +Standard full fine-tuning typically shows: |
| 120 | +- **Catastrophic forgetting**: Significant performance degradation on earlier tasks |
| 121 | +- **Last task bias**: Model performs well only on the most recent task |
| 122 | +- **Task interference**: New task learning overwrites previous knowledge |
| 123 | + |
| 124 | +## Understanding the Results |
| 125 | + |
| 126 | +### Forgetting Analysis |
| 127 | + |
| 128 | +The script prints a forgetting analysis showing how much earlier task performance changes. |
| 129 | + |
| 130 | +**Example results from training with 2 epochs per task:** |
| 131 | + |
| 132 | +``` |
| 133 | +SUMMARY METRICS |
| 134 | +================================================================================ |
| 135 | +
|
| 136 | +1. Average Accuracy Across All 3 Tasks (After Final Task): |
| 137 | + OSF: 53.42% |
| 138 | + Full FT: 46.26% |
| 139 | + Difference: +7.17% (OSF better) |
| 140 | +
|
| 141 | +2. Average Forgetting (Task 1 & 2): |
| 142 | + Forgetting = Final Accuracy - Initial Accuracy (negative is worse) |
| 143 | +
|
| 144 | + ScienceQA: |
| 145 | + OSF: +30.50% (initial: 55.00% → final: 85.50%) |
| 146 | + Full FT: -13.00% (initial: 84.50% → final: 71.50%) |
| 147 | + Difference: +43.50% (OSF better) |
| 148 | +
|
| 149 | + NumGLUE: |
| 150 | + OSF: +30.00% (initial: 16.00% → final: 46.00%) |
| 151 | + Full FT: +1.00% (initial: 37.50% → final: 38.50%) |
| 152 | + Difference: +29.00% (OSF better) |
| 153 | +
|
| 154 | + Average Forgetting: |
| 155 | + OSF: +30.25% |
| 156 | + Full FT: -6.00% |
| 157 | + Difference: +36.25% (OSF better) |
| 158 | +``` |
| 159 | + |
| 160 | +**Interpreting Forgetting Metrics:** |
| 161 | +- **Negative values** = Forgetting occurred (performance decreased) |
| 162 | +- **Positive values** = Backward transfer occurred (performance improved) |
| 163 | +- **Values near 0** = Performance on the earlier task is largely unchanged
| 164 | + |
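| 165 | +In this run, OSF shows positive backward transfer on both earlier tasks (+30.25 percentage points on average), while Full FT shows net forgetting (-6.00 points on average, driven by a 13-point drop on ScienceQA). Note that OSF's initial accuracies are lower than Full FT's (for example, 55.00% vs. 84.50% on ScienceQA), so part of OSF's gain reflects a lower starting point. Overall, the results suggest that OSF can limit catastrophic forgetting while still allowing beneficial knowledge transfer across tasks.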
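
The forgetting values reported above can be recomputed from the initial and final accuracies. The sketch below copies the numbers from the example output; the dictionary layout and the `forgetting` helper are illustrative and not part of the script.

```python
# Accuracies (%) copied from the example output: task -> (initial, final)
results = {
    "OSF": {"ScienceQA": (55.00, 85.50), "NumGLUE": (16.00, 46.00)},
    "Full FT": {"ScienceQA": (84.50, 71.50), "NumGLUE": (37.50, 38.50)},
}


def forgetting(initial, final):
    """Forgetting = final accuracy - initial accuracy (negative is worse)."""
    return final - initial


avg_forgetting = {}
for method, tasks in results.items():
    per_task = {name: forgetting(*acc) for name, acc in tasks.items()}
    avg_forgetting[method] = sum(per_task.values()) / len(per_task)
    print(method, per_task, f"average: {avg_forgetting[method]:+.2f}")
```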
| 166 | + |
| 167 | +## Advanced Usage |
| 168 | + |
| 169 | +### Custom Task Configuration |
| 170 | + |
| 171 | +You can modify the tasks and capacity allocation in the script: |
| 172 | + |
| 173 | +```python |
| 174 | +tasks = [
| 175 | +    {
| 176 | +        "name": "Task1",
| 177 | +        "train": task1_train,
| 178 | +        "eval": task1_eval,
| 179 | +        "effective_rank": 0.2,  # Freeze 20%, train 80%
| 180 | +    },
| 181 | +    {
| 182 | +        "name": "Task2",
| 183 | +        "train": task2_train,
| 184 | +        "eval": task2_eval,
| 185 | +        "effective_rank": 0.6,  # Freeze 60%, train 40%
| 186 | +    },
| 187 | +]
| 188 | +``` |
| 189 | + |
| 190 | +### Using Different Models |
| 191 | + |
| 192 | +OSF works with any transformer-based model: |
| 193 | + |
| 194 | +```bash |
| 195 | +# Smaller model for faster experimentation |
| 196 | +python osf_continual_learning.py --model_name gpt2 |
| 197 | + |
| 198 | +# Different LLaMA variant |
| 199 | +python osf_continual_learning.py --model_name meta-llama/Llama-2-7b-hf |
| 200 | +``` |
| 201 | + |
| 202 | +### Adjusting Target Modules |
| 203 | + |
| 204 | +In the script, you can modify which modules to apply OSF to: |
| 205 | + |
| 206 | +```python |
| 207 | +config = OSFConfig(
| 208 | +    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # Attention only
| 209 | +    effective_rank=task["effective_rank"],
| 210 | +)
| 211 | +``` |
| 212 | + |
| 213 | +Common configurations: |
| 214 | +- **Attention only**: `["q_proj", "k_proj", "v_proj", "o_proj"]` |
| 215 | +- **Attention + MLP**: `["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]` |
| 216 | +- **All linear**: `target_modules="all-linear"` |
| 217 | + |
| 218 | +## Customization |
| 219 | + |
| 220 | +### Adding Your Own Tasks |
| 221 | + |
| 222 | +To add custom tasks, create data loading and formatting functions in `utils.py`: |
18 | 223 |
|
19 | 224 | ```python |
20 | | -import torch |
21 | | -from transformers import AutoModelForCausalLM, AutoTokenizer |
22 | | -from peft import OSFConfig, get_peft_model |
| 225 | +def load_my_task(num_train=1000, num_eval=200, seed=42):
| 226 | +    """Load your custom dataset."""
| 227 | +    dataset = load_dataset("your/dataset")
| 228 | +    # ... split and return
| 229 | +    return train_dataset, eval_dataset
| 230 | +
| 231 | +def format_my_task_for_llama(examples, tokenizer, max_length=512):
| 232 | +    """Format your task for instruction following."""
| 233 | +    prompts = []
| 234 | +    labels_text = []
23 | 235 |
|
24 | | -model = AutoModelForCausalLM.from_pretrained("gpt2")
25 | | -config = OSFConfig(target_modules=["c_attn", "c_proj"], effective_rank=8)
26 | | -model = get_peft_model(model, config)
| 236 | +    # `examples` is a batch (dict of columns), so iterate over rows, not keys
| 236 | +    for i in range(len(examples['input'])):
| 237 | +        prompt = f"Your instruction template: {examples['input'][i]}"
| 238 | +        label = examples['output'][i]
27 | 239 |
|
28 | | -optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
| 240 | +        prompts.append(prompt)
| 241 | +        labels_text.append(label)
29 | 242 |
|
30 | | -tokenizer = AutoTokenizer.from_pretrained("gpt2")
31 | | -tokenizer.pad_token = tokenizer.eos_token
32 | | -inputs = tokenizer("Hello world", return_tensors="pt", padding=True)
33 | | -loss = model(**inputs, labels=inputs.input_ids).loss
34 | | -loss.backward()
35 | | -optimizer.step()
36 | | -optimizer.zero_grad()
| 243 | +    # ... tokenization logic
| 244 | +    return formatted_examples
37 | 245 | ``` |
| 246 | + |
| 247 | +Then add the new task to the `tasks` list in `osf_continual_learning.py`.
| 248 | + |
| 249 | +## Performance Tips |
| 250 | + |
| 251 | +### Memory Optimization |
| 252 | + |
| 253 | +For large models, consider: |
| 254 | +- Reducing `batch_size` and increasing `gradient_accumulation_steps` |
| 255 | +- Using smaller `max_length` |
| 256 | +- Enabling gradient checkpointing (add to model before OSF): |
| 257 | + ```python |
| 258 | + model.gradient_checkpointing_enable() |
| 259 | + ``` |
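
When trading `batch_size` for `gradient_accumulation_steps`, it helps to keep the effective batch size constant so the optimizer sees the same number of samples per update. A small sketch, assuming the script's defaults (`batch_size=32`, `gradient_accumulation_steps=1`); the helper and its `num_devices` parameter are hypothetical:

```python
def effective_batch_size(batch_size, gradient_accumulation_steps, num_devices=1):
    """Number of samples contributing to each optimizer step."""
    return batch_size * gradient_accumulation_steps * num_devices


default = effective_batch_size(32, 1)       # script defaults
low_memory = effective_batch_size(4, 8)     # smaller per-device batch, more accumulation
print(default, low_memory)
```

Both settings update the model with the same number of samples per step, but the second holds far fewer activations in memory at once.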
| 260 | + |
| 261 | +### Training Speed |
| 262 | + |
| 263 | +To speed up training: |
| 264 | +- Reduce `num_train` and `num_eval` for initial testing |
| 265 | +- Use smaller models (e.g., `gpt2` or `Llama-2-7b`) |
| 266 | +- Reduce `max_length` for shorter sequences |
| 267 | + |
| 268 | +### Better Results |
| 269 | + |
| 270 | +For improved continual learning performance: |
| 271 | +- Tune `num_epochs` per task (try 2-3 epochs)
| 272 | +- Adjust `learning_rate` |
| 273 | +- Experiment with different capacity allocation strategies |
| 274 | + |
| 275 | +## Citation |
| 276 | + |
| 277 | +If you use OSF in your research, please cite: |
| 278 | + |
| 279 | +```bibtex |
| 280 | +@misc{nayak2025sculptingsubspacesconstrainedfinetuning, |
| 281 | + title={Sculpting Subspaces: Constrained Full Fine-Tuning in LLMs for Continual Learning}, |
| 282 | + author={Nikhil Shivakumar Nayak and Krishnateja Killamsetty and Ligong Han and Abhishek Bhandwaldar and Prateek Chanda and Kai Xu and Hao Wang and Aldo Pareja and Oleg Silkin and Mustafa Eyceoz and Akash Srivastava}, |
| 283 | + year={2025}, |
| 284 | + eprint={2504.07097}, |
| 285 | + archivePrefix={arXiv}, |
| 286 | + primaryClass={cs.LG}, |
| 287 | + url={https://arxiv.org/abs/2504.07097}, |
| 288 | +} |
| 289 | +``` |
| 290 | + |
| 291 | +## Additional Resources |
| 292 | + |
| 293 | +- [OSF Documentation](../../docs/source/package_reference/osf.md) |
| 294 | +- [PEFT Documentation](https://huggingface.co/docs/peft) |
| 295 | +- [Original Paper](https://arxiv.org/abs/2504.07097) |
| 296 | + |
| 297 | +## License |
| 298 | + |
| 299 | +This example is licensed under Apache 2.0. See the PEFT repository for full license details. |