Skip to content
This repository was archived by the owner on Nov 13, 2025. It is now read-only.

Commit 8c22a4c

Browse files
committed
complete uv
1 parent b538a76 commit 8c22a4c

File tree

4 files changed

+1936
-0
lines changed

4 files changed

+1936
-0
lines changed

.python-version

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
3.8.20

pyproject.toml

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
[project]
name = "cara"
version = "0.1.0"
description = "[ICML2025] Canonical Rank Adaptation"
readme = "README.md"
# Pinned to an exact patch release; matches .python-version (3.8.20).
requires-python = "==3.8.20"
dependencies = [
    "avalanche-lib==0.3.0",
    "matplotlib==3.7.5",
    "numpy==1.24.4",
    "tensorly==0.8.1",
    "timm==0.4.12",
    "tqdm==4.66.6",
]


# torch/torchvision are installed via one of two mutually exclusive extras
# (CPU-only wheels vs CUDA 11.8 wheels); the version pins are identical and
# only the wheel index differs (see [tool.uv.sources] below).
[project.optional-dependencies]
cpu = [
    "torch==2.0.0",
    "torchvision==0.15",
]
cu118 = [
    "torch==2.0.0",
    "torchvision==0.15",
]

# Tell uv the two extras may never be resolved together.
[tool.uv]
conflicts = [
    [
        { extra = "cpu" },
        { extra = "cu118" },
    ],
]

# Route torch/torchvision to the matching PyTorch wheel index per extra.
[tool.uv.sources]
torch = [
    { index = "pytorch-cpu", extra = "cpu" },
    { index = "pytorch-cu118", extra = "cu118" },
]
torchvision = [
    { index = "pytorch-cpu", extra = "cpu" },
    { index = "pytorch-cu118", extra = "cu118" },
]

# explicit = true: only packages pinned via [tool.uv.sources] use this index.
[[tool.uv.index]]
name = "pytorch-cpu"
url = "https://download.pytorch.org/whl/cpu"
explicit = true
priority = "primary"

[[tool.uv.index]]
name = "pytorch-cu118"
url = "https://download.pytorch.org/whl/cu118"
explicit = true


# Linting / formatting / type-checking configuration (dev tooling only).
[tool.isort]
profile = "black"
py_version = 38

[tool.black]
line-length = 79
target-version = ['py38']

[tool.mypy]
python_version = "3.8.20" # or your target Python version
install_types = true
non_interactive = true
ignore_missing_imports = true
strict_optional = false
warn_return_any = false
implicit_reexport = true
allow_untyped_calls = true
explicit_package_bases = true

[dependency-groups]
dev = [
    "black",
    "darglint",
    "flake8",
    "flake8-black",
    "flake8-broken-line",
    "flake8-bugbear",
    "flake8-docstrings",
    "isort",
    "mypy",
    "pep8-naming",
    "pre-commit",
    "pydocstyle",
    "types-pyyaml",
    "types-requests",
    "types-simplejson",
    "types-tabulate",
]

src/cara/cara.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
from copy import deepcopy
2+
from typing import Dict, Any
3+
4+
import torch as th
5+
import torch.nn as nn
6+
import timm
7+
import tensorly as tl
8+
tl.set_backend("pytorch")
9+
10+
11+
def cp_attn(self, x):
    """CaRA forward pass for a timm Attention block.

    Bound onto ``timm.models.vision_transformer.Attention`` instances by
    ``set_cara`` (it replaces ``forward``), so ``self`` is that Attention
    module and carries the extra attributes installed there (``dp``, ``s``,
    ``idx``, ``attn_idx``).  Reads the shared CP factor matrices from the
    module-level ``global_model`` set by ``cara()``.

    Args:
        x: token embeddings of shape (B, N, C), unpacked below.

    Returns:
        Attention output of shape (B, N, C).
    """
    B, N, C = x.shape
    # Frozen pretrained fused q/k/v projection: (B, N, 3*C).
    qkv = self.qkv(x)
    # Three consecutive rows of CP_A1 (one per q, k, v) belong to this block.
    f1 = global_model.CP_A1[self.attn_idx:self.attn_idx+3, :]
    # Reconstruct the low-rank weight update from its CP factors;
    # shape is (3, embed_dim, num_heads, head_dim) given the factor sizes.
    tensor_attn = tl.cp_to_tensor((global_model.CP_R1, (f1, global_model.CP_A2, global_model.CP_A3, global_model.CP_A4)))
    K, E, H, D = tensor_attn.shape
    # Merge heads and head_dim so each of the K=3 slices is an (E, H*D) matrix.
    tensor_attn = tensor_attn.reshape((K, E, H*D))
    # Apply dropout to the reconstructed weights, then project the tokens:
    # result is (3, B, N, H*D).
    qkv_delta = th.einsum("bnd, kde->kbne", x, self.dp(tensor_attn))
    # Reshape both the delta and the frozen qkv to (3, B, heads, N, head_dim)
    # so they can be summed elementwise.
    qkv_delta = qkv_delta.reshape(3, B, N, self.num_heads, C//self.num_heads).permute(
        0, 1, 3, 2, 4
    )
    qkv = qkv.reshape(B, N, 3, self.num_heads, C//self.num_heads).permute(
        2, 0, 3, 1, 4
    )
    # Add the scaled low-rank update to the pretrained projection.
    qkv += qkv_delta * self.s
    q, k, v = qkv[0], qkv[1], qkv[2]
    # Standard scaled dot-product attention (as in timm's Attention.forward).
    attn = (q @ k.transpose(-2, -1)) * self.scale
    attn = attn.softmax(dim=-1)
    attn = self.attn_drop(attn)

    x = (attn@v).transpose(1, 2).reshape(B, N, C)

    # Output projection, again with an additive CaRA delta: this block owns
    # one row of CP_P1 (Mlp blocks use eight rows each; see set_cara).
    proj = self.proj(x)
    p1 = global_model.CP_P1[self.idx:self.idx+1, :]
    tensor_proj = tl.cp_to_tensor((global_model.CP_R2, (p1, global_model.CP_P2, global_model.CP_P3)))
    AA, AB, AC = tensor_proj.shape
    # Flatten the leading two axes into a single (AA*AB, AC) weight matrix.
    tensor_proj = tensor_proj.reshape((AA*AB, AC))
    proj_delta = x@self.dp(tensor_proj.T) + global_model.CP_bias1
    proj += proj_delta * self.s
    x = self.proj_drop(proj)
    return x
42+
43+
44+
def cp_mlp(self, x):
    """CaRA forward pass for a timm Mlp block.

    Bound onto ``timm.models.layers.mlp.Mlp`` instances by ``set_cara``
    (it replaces ``forward``); ``self`` therefore carries the extra
    attributes installed there (``dp``, ``s``, ``idx``).  The shared CP
    factors live on the module-level ``global_model`` set by ``cara()``.

    Args:
        x: token embeddings entering the MLP.

    Returns:
        MLP output with scaled low-rank deltas added to both projections.
    """
    gm = global_model
    # This block owns eight consecutive rows of CP_P1: four for fc1 (up
    # projection) and four for fc2 (down projection).
    rows_fc1 = gm.CP_P1[self.idx:self.idx + 4, :]
    rows_fc2 = gm.CP_P1[self.idx + 4:self.idx + 8, :]

    def reconstruct(rows):
        # Rebuild the CP tensor for the given rows and collapse its two
        # leading axes into a single weight matrix.
        w = tl.cp_to_tensor((gm.CP_R2, (rows, gm.CP_P2, gm.CP_P3)))
        d0, d1, d2 = w.shape
        return w.reshape((d0 * d1, d2))

    # Up projection: frozen fc1 plus scaled low-rank delta (note the
    # transpose — the reconstructed matrix maps hidden -> input dims).
    hidden = self.fc1(x)
    delta_up = x @ self.dp(reconstruct(rows_fc1).T) + gm.CP_bias2
    hidden = hidden + delta_up * self.s

    # Nonlinearity and dropout, as in timm's Mlp.forward.
    h = self.drop(self.act(hidden))

    # Down projection: frozen fc2 plus scaled low-rank delta (no transpose).
    out = self.fc2(h)
    delta_down = h @ self.dp(reconstruct(rows_fc2)) + gm.CP_bias3
    out = out + delta_down * self.s
    return self.drop(out)
65+
66+
67+
def set_cara(model: nn.Module, rank: int, scale: float, l_mu: float, l_std: float):
    """Recursively attach CaRA parameters and patched forwards to a ViT.

    On the top-level ``timm`` VisionTransformer this declares and initialises
    the shared CP factor matrices; on every Attention/Mlp leaf it installs
    per-block attributes and rebinds ``forward`` to ``cp_attn``/``cp_mlp``.
    Relies on the module-level ``global_model`` set by ``cara()`` for the
    shared index counters, so call it via ``cara()``.

    Args:
        model: module (sub)tree to patch.
        rank: CP rank of the adaptation tensors.
        scale: multiplier applied to every low-rank delta.
        l_mu: mean (or constant value, when ``l_std == 0``) for the CP
            weight vectors ``CP_R1``/``CP_R2``.
        l_std: std of the normal init for ``CP_R1``/``CP_R2``; 0 selects a
            constant fill with ``l_mu``.
    """
    if type(model) == timm.models.vision_transformer.VisionTransformer:
        # Declare CaRA parameters.
        # Sizes assume ViT-Base: embed_dim 768, 12 heads, 12 blocks — TODO
        # confirm before using with other ViT variants.
        # 36 = 12 blocks * 3 (q, k, v) rows; 108 = 12 blocks * (1 attn-proj
        # row + 8 mlp rows).
        model.CP_A1 = nn.Parameter(th.empty([36, rank]), requires_grad=True)
        model.CP_A2 = nn.Parameter(th.empty([768, rank]), requires_grad=True)
        model.CP_A3 = nn.Parameter(th.empty([12, rank]), requires_grad=True)
        model.CP_A4 = nn.Parameter(th.empty([768//12, rank]), requires_grad=True)
        model.CP_P1 = nn.Parameter(th.empty([108, rank]), requires_grad=True)
        model.CP_P2 = nn.Parameter(th.empty([768, rank]), requires_grad=True)
        model.CP_P3 = nn.Parameter(th.empty([768, rank]), requires_grad=True)
        model.CP_R1 = nn.Parameter(th.empty([rank]), requires_grad=True)
        model.CP_R2 = nn.Parameter(th.empty([rank]), requires_grad=True)
        model.CP_bias1 = nn.Parameter(th.empty([768]), requires_grad=True)
        model.CP_bias2 = nn.Parameter(th.empty([768*4]), requires_grad=True)
        model.CP_bias3 = nn.Parameter(th.empty([768]), requires_grad=True)
        # Initialise CaRA parameters. Zero-init on CP_A2/CP_P2 makes every
        # delta start at zero, so the patched model initially matches the
        # pretrained one.
        nn.init.xavier_normal_(model.CP_A1)
        nn.init.zeros_(model.CP_A2)
        nn.init.orthogonal_(model.CP_A3)
        nn.init.orthogonal_(model.CP_A4)
        nn.init.xavier_normal_(model.CP_P1)
        nn.init.zeros_(model.CP_P2)
        nn.init.orthogonal_(model.CP_P3)
        if l_std != 0.0:
            nn.init.normal_(model.CP_R1, mean=l_mu, std=l_std)
            nn.init.normal_(model.CP_R2, mean=l_mu, std=l_std)
        else:
            # BUGFIX: the original only handled l_mu == 1.0 here (ones_),
            # leaving CP_R1/CP_R2 as uninitialised th.empty garbage for any
            # other l_mu with l_std == 0. constant_ is identical to ones_
            # when l_mu == 1.0 and deterministic for every other l_mu.
            nn.init.constant_(model.CP_R1, l_mu)
            nn.init.constant_(model.CP_R2, l_mu)
        nn.init.zeros_(model.CP_bias1)
        nn.init.zeros_(model.CP_bias2)
        nn.init.zeros_(model.CP_bias3)
        # CaRA indexing: shared row counters into CP_P1 / CP_A1, advanced as
        # blocks are visited in order below.
        model.idx = 0
        model.attn_idx = 0
    for child in model.children():
        # Exact type checks (not isinstance) are deliberate: only patch the
        # stock timm Attention/Mlp classes.
        if type(child) == timm.models.vision_transformer.Attention:
            child.dp = nn.Dropout(0.1)
            child.s = scale
            child.dim = rank
            child.idx = global_model.idx
            child.attn_idx = global_model.attn_idx
            global_model.idx += 1        # one CP_P1 row for the output proj
            global_model.attn_idx += 3   # three CP_A1 rows for q, k, v
            # Rebind the module's forward to the CaRA version.
            bound_method = cp_attn.__get__(child, child.__class__)
            setattr(child, "forward", bound_method)
        elif type(child) == timm.models.layers.mlp.Mlp:
            child.dp = nn.Dropout(0.1)
            child.s = scale
            child.dim = rank
            child.idx = global_model.idx
            global_model.idx += 8        # four CP_P1 rows each for fc1/fc2
            bound_method = cp_mlp.__get__(child, child.__class__)
            setattr(child, "forward", bound_method)
        elif len(list(child.children())) != 0:
            # Recurse into containers (e.g. the Block list) to find leaves.
            set_cara(child, rank, scale, l_mu, l_std)
123+
124+
125+
def cara(config):
    """Apply CaRA to the model described by ``config``.

    Expects ``config`` to provide the keys ``model``, ``rank``, ``scale``,
    ``l_mu`` and ``l_std``.  Publishes the model as the module-level
    ``global_model`` (read by the patched forwards) before delegating to
    ``set_cara``.

    Args:
        config: mapping with the CaRA hyper-parameters listed above.

    Returns:
        The patched model (same object as ``config["model"]``).
    """
    global global_model
    global_model = config["model"]
    set_cara(
        global_model,
        config["rank"],
        config["scale"],
        config["l_mu"],
        config["l_std"],
    )
    return global_model

0 commit comments

Comments
 (0)