
Commit 87c73b4

Add OSF continual learning example
1 parent 30a19a0 commit 87c73b4

File tree: 3 files changed (+1288, -26 lines)

Lines changed: 288 additions & 26 deletions

# Orthogonal Subspace Fine-tuning (OSF) - Continual Learning Example

This example demonstrates OSF's ability to learn multiple tasks sequentially while preventing catastrophic forgetting, a key challenge in continual learning.

## Introduction

**Orthogonal Subspace Fine-tuning (OSF)** is a parameter-efficient fine-tuning method designed specifically for continual learning scenarios. Unlike traditional fine-tuning, which suffers from catastrophic forgetting when learning new tasks, OSF constrains parameter updates to be orthogonal to previously important directions, effectively preserving knowledge from earlier tasks.

### Key Features

- **Prevents Catastrophic Forgetting**: Maintains performance on previous tasks while learning new ones
- **Full Model Capacity**: Unlike LoRA-based methods, OSF allows full-rank updates within the trainable subspace
- **Progressive Budget Allocation**: Gradually allocates more capacity to preserving previous knowledge as tasks accumulate
- **No Additional Parameters**: Modifies weights in place without adding extra parameters per task

## Quick Start

### Installation

```bash
pip install -e ".[dev]"
```

### Basic Usage

Run the continual learning example with OSF:

```bash
python osf_continual_learning.py \
    --model_name meta-llama/Llama-3.1-8B-Instruct \
    --num_train 1000 \
    --num_eval 200 \
    --num_epochs 2 \
    --output_dir ./outputs
```

To compare against the full fine-tuning baseline:

```bash
python osf_continual_learning.py \
    --model_name meta-llama/Llama-3.1-8B-Instruct \
    --run_baseline \
    --output_dir ./outputs
```

## Continual Learning Scenario

This example trains a model on three different tasks sequentially:

1. **ScienceQA** - Science question answering across natural, language, and social sciences
2. **NumGLUE** - Mathematical reasoning and numerical understanding
3. **FOMC** - Financial sentiment classification (Dovish/Hawkish/Neutral)

### Progressive Capacity Allocation

OSF uses a progressive budget allocation strategy where each task gets decreasing trainable capacity while preserving more knowledge from previous tasks (a small sketch of how this split works follows the list below):

| Task | Effective Rank | Preserved | Trainable | Description |
|------|----------------|-----------|-----------|-------------|
| Task 1 (ScienceQA) | 0.3 | 30% | 70% | Maximum capacity for first task |
| Task 2 (NumGLUE) | 0.5 | 50% | 50% | Balanced capacity allocation |
| Task 3 (FOMC) | 0.7 | 70% | 30% | Minimal capacity, maximum preservation |

This allocation ensures:
- Early tasks get sufficient capacity to learn effectively
- Later tasks can still learn new patterns
- Previous knowledge is progressively protected from interference

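As a rough illustration of the schedule above, the snippet below shows how a fractional `effective_rank` could translate into a preserved/trainable split of singular directions for a single weight matrix. This is a minimal sketch for intuition only: the helper name and the 4096x4096 shape are invented here, and the exact rounding used inside OSF may differ.

```python
# Illustrative only: how a fractional effective_rank could map to a
# preserved/trainable split of singular directions for one weight matrix.
# Assumes effective_rank is the fraction of directions frozen, as in the table above.

def split_singular_directions(weight_shape, effective_rank):
    """Return (num_frozen, num_trainable) singular directions for a weight matrix."""
    full_rank = min(weight_shape)  # maximum possible rank of the matrix
    num_frozen = int(round(effective_rank * full_rank))
    return num_frozen, full_rank - num_frozen

for task, rank in [("ScienceQA", 0.3), ("NumGLUE", 0.5), ("FOMC", 0.7)]:
    frozen, trainable = split_singular_directions((4096, 4096), rank)
    print(f"{task}: freeze {frozen} directions, train {trainable}")
```
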
## How OSF Works

OSF decomposes each weight matrix using SVD into high-rank (preserved) and low-rank (trainable) components:

```
W = U_high @ S_high @ V_high^T + U_low @ S_low @ V_low^T
        └─────────┬─────────┘        └──────┬──────┘
                frozen                   trainable
          (previous tasks)            (current task)
```

During training:
1. **Initialization**: Perform SVD on each weight matrix
2. **Partitioning**: Split singular values based on `effective_rank`
3. **Freezing**: Freeze top-k singular directions (high-rank subspace)
4. **Training**: Update remaining directions (low-rank subspace)
5. **Gradient Projection**: Ensure updates are orthogonal to frozen subspace

Between tasks:
1. **Unload**: Merge OSF components back into base model
2. **Re-initialize**: Perform fresh SVD with increased `effective_rank`
3. **Continue**: Train on next task with larger frozen subspace

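To make the decomposition and gradient projection above concrete, here is a minimal PyTorch sketch of the core idea. It is not the PEFT implementation: the function names are invented for illustration, the projection shown is a simplified one-sided projection onto the complement of the frozen left singular subspace (the actual method may also constrain the right singular directions), and bookkeeping such as dtypes and non-square shapes is omitted.

```python
import torch

def partition_weight(W, effective_rank):
    """Split W into a frozen high-rank part and a trainable low-rank part via SVD."""
    U, S, Vh = torch.linalg.svd(W, full_matrices=False)
    k = int(round(effective_rank * S.numel()))        # number of frozen singular directions
    W_frozen = U[:, :k] @ torch.diag(S[:k]) @ Vh[:k]  # preserved subspace (previous tasks)
    W_train = U[:, k:] @ torch.diag(S[k:]) @ Vh[k:]   # trainable subspace (current task)
    return W_frozen, W_train, U[:, :k]

def project_out_frozen(grad, U_frozen):
    """Remove the gradient component that lies in the frozen column space."""
    return grad - U_frozen @ (U_frozen.T @ grad)

# Toy usage: updates to the trainable part stay orthogonal to the frozen directions.
W = torch.randn(64, 64)
W_frozen, W_train, U_frozen = partition_weight(W, effective_rank=0.5)
grad = torch.randn_like(W_train)
safe_grad = project_out_frozen(grad, U_frozen)
print((U_frozen.T @ safe_grad).abs().max())  # ~0: no component left in the frozen subspace
```
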
## Command Line Arguments

```
--model_name                    Model to use (default: meta-llama/Llama-3.1-8B-Instruct)
--num_train                     Number of training samples per task (default: 1000)
--num_eval                      Number of evaluation samples per task (default: 200)
--output_dir                    Directory for outputs (default: ./osf_continual_learning_outputs)
--num_epochs                    Training epochs per task (default: 2)
--learning_rate                 Learning rate (default: 5e-6)
--batch_size                    Batch size per device (default: 32)
--gradient_accumulation_steps   Gradient accumulation (default: 1)
--max_length                    Maximum sequence length (default: 512)
--seed                          Random seed (default: 42)
--run_baseline                  Also run full fine-tuning baseline for comparison
```

## Expected Results

### OSF Performance

When using OSF (with 2 epochs per task), you should observe:
- **Reduced catastrophic forgetting**: Performance on earlier tasks degrades less than with full fine-tuning
- **Continued learning**: The model successfully learns each new task
- **Better retention**: OSF maintains higher average accuracy across all tasks

### Full Fine-tuning Baseline

Standard full fine-tuning typically shows:
- **Catastrophic forgetting**: Significant performance degradation on earlier tasks
- **Last task bias**: The model performs well only on the most recent task
- **Task interference**: Learning a new task overwrites previous knowledge

## Understanding the Results

### Forgetting Analysis

The script prints a forgetting analysis showing how much performance on earlier tasks changes after later tasks are trained.

**Example results from training with 2 epochs per task:**

```
SUMMARY METRICS
================================================================================

1. Average Accuracy Across All 3 Tasks (After Final Task):
   OSF:     53.42%
   Full FT: 46.26%
   Difference: +7.17% (OSF better)

2. Average Forgetting (Task 1 & 2):
   Forgetting = Final Accuracy - Initial Accuracy (negative is worse)

   ScienceQA:
     OSF:     +30.50% (initial: 55.00% → final: 85.50%)
     Full FT: -13.00% (initial: 84.50% → final: 71.50%)
     Difference: +43.50% (OSF better)

   NumGLUE:
     OSF:     +30.00% (initial: 16.00% → final: 46.00%)
     Full FT: +1.00% (initial: 37.50% → final: 38.50%)
     Difference: +29.00% (OSF better)

   Average Forgetting:
     OSF:     +30.25%
     Full FT: -6.00%
     Difference: +36.25% (OSF better)
```

**Interpreting Forgetting Metrics:**
- **Negative values** = Forgetting occurred (performance decreased)
- **Positive values** = Backward transfer occurred (performance improved)
- **Values closer to 0** = Better retention

In this example, OSF shows significant positive backward transfer (+30.25% on average), while full fine-tuning shows slight forgetting (-6.00% on average). This demonstrates OSF's ability to not only prevent catastrophic forgetting but also enable beneficial knowledge transfer across tasks.

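For reference, the forgetting numbers above follow directly from the reported accuracies. The tiny sketch below recomputes them; the accuracies are hard-coded from the example output and the helper name is hypothetical.

```python
# Forgetting metric as defined above:
# forgetting = accuracy after the final task - accuracy right after the task was learned.

def forgetting(initial_acc, final_acc):
    return final_acc - initial_acc

# Accuracies (in percent) taken from the example OSF run above.
osf = {"ScienceQA": (55.00, 85.50), "NumGLUE": (16.00, 46.00)}

scores = {task: forgetting(*accs) for task, accs in osf.items()}
print(scores)                              # {'ScienceQA': 30.5, 'NumGLUE': 30.0}
print(sum(scores.values()) / len(scores))  # 30.25, matching "Average Forgetting: OSF +30.25%"
```
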
## Advanced Usage

### Custom Task Configuration

You can modify the tasks and capacity allocation in the script:

```python
tasks = [
    {
        "name": "Task1",
        "train": task1_train,
        "eval": task1_eval,
        "effective_rank": 0.2,  # Freeze 20%, train 80%
    },
    {
        "name": "Task2",
        "train": task2_train,
        "eval": task2_eval,
        "effective_rank": 0.6,  # Freeze 60%, train 40%
    },
]
```

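The script iterates over this list, re-wrapping the model for each task. A condensed sketch of that loop is shown below; it assumes a `train_on_task` helper that does not exist under that name, and it uses `merge_and_unload()` the way other PEFT methods do, so treat the unloading call as an assumption rather than the script's literal code.

```python
from transformers import AutoModelForCausalLM
from peft import OSFConfig, get_peft_model

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")

for task in tasks:
    # Re-initialize OSF for this task: a fresh SVD with a larger frozen subspace.
    config = OSFConfig(
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        effective_rank=task["effective_rank"],
    )
    peft_model = get_peft_model(model, config)

    train_on_task(peft_model, task["train"])  # hypothetical training helper

    # Fold the learned low-rank update back into the base weights before the next task.
    model = peft_model.merge_and_unload()
```
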
### Using Different Models

OSF works with any transformer-based model:

```bash
# Smaller model for faster experimentation
python osf_continual_learning.py --model_name gpt2

# Different LLaMA variant
python osf_continual_learning.py --model_name meta-llama/Llama-2-7b-hf
```

### Adjusting Target Modules

In the script, you can modify which modules OSF is applied to:

```python
config = OSFConfig(
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],  # Attention only
    effective_rank=task["effective_rank"],
)
```

Common configurations:
- **Attention only**: `["q_proj", "k_proj", "v_proj", "o_proj"]`
- **Attention + MLP**: `["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]`
- **All linear layers**: `target_modules="all-linear"`

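For instance, to cover both attention and MLP projections, or every linear layer, the configurations above can be written out as follows. The module names are the LLaMA-style names from the list above; whether they match your model depends on its architecture.

```python
from peft import OSFConfig

# Attention + MLP projections on a LLaMA-style model.
config = OSFConfig(
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    effective_rank=0.5,
)

# Or target every linear layer, as noted above.
config_all = OSFConfig(target_modules="all-linear", effective_rank=0.5)
```
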
## Customization

### Adding Your Own Tasks

To add custom tasks, create data loading and formatting functions in `utils.py`:

```python
from datasets import load_dataset

def load_my_task(num_train=1000, num_eval=200, seed=42):
    """Load your custom dataset."""
    dataset = load_dataset("your/dataset")
    # ... split and return
    return train_dataset, eval_dataset

def format_my_task_for_llama(examples, tokenizer, max_length=512):
    """Format your task for instruction following."""
    prompts = []
    labels_text = []

    for i in range(len(examples)):
        prompt = f"Your instruction template: {examples['input'][i]}"
        label = examples['output'][i]

        prompts.append(prompt)
        labels_text.append(label)

    # ... tokenization logic
    return formatted_examples
```

Then add the new task to the `tasks` list in `osf_continual_learning.py`.

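As a concrete illustration, a new entry could look like the sketch below; the fields mirror the Custom Task Configuration section above, `load_my_task` comes from the sketch just shown, and the `effective_rank` value is only an example continuing the progressive schedule.

```python
# In osf_continual_learning.py: register the custom task alongside the built-in ones.
my_train, my_eval = load_my_task(num_train=1000, num_eval=200, seed=42)

tasks.append(
    {
        "name": "MyTask",
        "train": my_train,
        "eval": my_eval,
        "effective_rank": 0.8,  # Freeze 80% of directions if this task is learned last.
    }
)
```
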
## Performance Tips

### Memory Optimization

For large models, consider:
- Reducing `batch_size` and increasing `gradient_accumulation_steps`
- Using a smaller `max_length`
- Enabling gradient checkpointing (add to the model before applying OSF):

```python
model.gradient_checkpointing_enable()
```

### Training Speed

To speed up training:
- Reduce `num_train` and `num_eval` for initial testing
- Use smaller models (e.g., `gpt2` or `Llama-2-7b`)
- Reduce `max_length` for shorter sequences

### Better Results

For improved continual learning performance:
- Experiment with `num_epochs` per task (2-3 epochs is a good starting point)
- Adjust the `learning_rate`
- Try different capacity allocation strategies (i.e., different `effective_rank` schedules)

## Citation

If you use OSF in your research, please cite:

```bibtex
@misc{nayak2025sculptingsubspacesconstrainedfinetuning,
      title={Sculpting Subspaces: Constrained Full Fine-Tuning in LLMs for Continual Learning},
      author={Nikhil Shivakumar Nayak and Krishnateja Killamsetty and Ligong Han and Abhishek Bhandwaldar and Prateek Chanda and Kai Xu and Hao Wang and Aldo Pareja and Oleg Silkin and Mustafa Eyceoz and Akash Srivastava},
      year={2025},
      eprint={2504.07097},
      archivePrefix={arXiv},
      primaryClass={cs.LG},
      url={https://arxiv.org/abs/2504.07097},
}
```

## Additional Resources

- [OSF Documentation](../../docs/source/package_reference/osf.md)
- [PEFT Documentation](https://huggingface.co/docs/peft)
- [Original Paper](https://arxiv.org/abs/2504.07097)

## License

This example is licensed under Apache 2.0. See the PEFT repository for full license details.
