Skip to content

Commit 958a60f

Browse files
rebasing to make use of simplified basetuner implementation and adding more experiment results
1 parent 00073fe commit 958a60f

File tree

9 files changed

+450
-197
lines changed

9 files changed

+450
-197
lines changed

docs/source/package_reference/osf.md

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -88,16 +88,16 @@ config = OSFConfig() # Uses model-appropriate defaults
8888

8989
### Effective Rank Configuration
9090

91-
Control the decomposition rank:
91+
Control the preserved/trainable subspaces:
9292

9393
```python
94-
# Global rank (applies to all target modules)
95-
config = OSFConfig(effective_rank=16)
94+
# Global preserved rank (applies to all target modules)
95+
config = OSFConfig(effective_rank=16) # preserves top-16 singular directions; trains the rest
9696

97-
# Automatic rank (50% of the smaller matrix dimension per target)
97+
# Automatic preserved rank (50% of the smaller matrix dimension per target)
9898
config = OSFConfig(effective_rank=None)
9999

100-
# Per-module rank overrides
100+
# Per-module preserved-rank overrides
101101
config = OSFConfig(
102102
effective_rank=8,
103103
rank_pattern={
@@ -107,6 +107,9 @@ config = OSFConfig(
107107
)
108108
```
109109

110+
Note: OSF's `effective_rank` is the preserved (frozen) rank, not the trainable rank. The trainable rank equals `min(weight.shape) - effective_rank`. This differs from LoRA's `r`, which directly specifies the trainable rank.
111+
112+
110113
## Training Advice for Continual Learning
111114

112115
### Sequential Task Learning
@@ -120,13 +123,13 @@ model = get_peft_model(base_model, OSFConfig(effective_rank=r))
120123
train_task(model, task_1_data)
121124

122125
# Task 2: recompute SVD on updated weights and increase preserved subspace
123-
base_model = model.base_model.model # unwrap updated base
126+
base_model = model.unload() # unwrap base model without assuming internals
124127
r += 4 # grow preserved subspace to include Task 1 knowledge
125128
model = get_peft_model(base_model, OSFConfig(effective_rank=r))
126129
train_task(model, task_2_data)
127130

128131
# Task 3: recompute again and expand preserved subspace further
129-
base_model = model.base_model.model
132+
base_model = model.unload()
130133
r += 4
131134
model = get_peft_model(base_model, OSFConfig(effective_rank=r))
132135
train_task(model, task_3_data)
@@ -146,23 +149,23 @@ This approach ensures each task gets adequate learning capacity while progressiv
146149
```python
147150
# Example: 4-task sequence with progressive budget allocation
148151
n_tasks = 4
149-
base_rank = 32 # Starting rank for full capacity
152+
max_preserved_rank = 512 # Upper bound for preserved rank per target (heuristic)
150153

151154
for task_id in range(n_tasks):
152-
# Calculate remaining capacity for current task
153-
freeze_fraction = task_id / n_tasks
154-
remaining_capacity = 1.0 - freeze_fraction
155-
current_rank = int(base_rank * remaining_capacity)
156-
155+
# Freeze increases over time; trainable capacity shrinks
156+
preserved_fraction = (task_id + 1) / n_tasks
157+
preserved_rank = int(max_preserved_rank * preserved_fraction)
158+
157159
config = OSFConfig(
158160
target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
159-
effective_rank=current_rank
161+
effective_rank=preserved_rank,
160162
)
161-
162-
print(f"Task {task_id + 1}: Using rank {current_rank} "
163-
f"({remaining_capacity:.1%} of full capacity)")
164-
165-
# Train on current task
163+
164+
print(
165+
f"Task {task_id + 1}: Preserving rank {preserved_rank} "
166+
f"({preserved_fraction:.1%} of max_preserved_rank={max_preserved_rank} frozen); trainable rank = min_dim - preserved_rank"
167+
)
168+
166169
model = get_peft_model(base_model, config)
167170
train_task(model, task_data[task_id])
168171
```
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
{
2+
"task_type": null,
3+
"peft_type": "OSF",
4+
"auto_mapping": null,
5+
"base_model_name_or_path": "meta-llama/Llama-3.2-3B",
6+
"revision": null,
7+
"inference_mode": false,
8+
"effective_rank": null,
9+
"target_modules": [
10+
"q_proj",
11+
"k_proj",
12+
"v_proj",
13+
"o_proj",
14+
"gate_proj",
15+
"down_proj",
16+
"up_proj"
17+
],
18+
"rank_pattern": {
19+
"q_proj": 2944,
20+
"o_proj": 2944,
21+
"k_proj": 896,
22+
"v_proj": 896,
23+
"gate_proj": 2944,
24+
"down_proj": 2944,
25+
"up_proj": 2944
26+
}
27+
}
28+
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"optimizer_kwargs": {
3+
"lr": 5e-5
4+
}
5+
}
6+

0 commit comments

Comments
 (0)