Skip to content
This repository was archived by the owner on Nov 13, 2025. It is now read-only.

Commit ff52c03

Browse files
committed
update flake8 and format
1 parent 70f2645 commit ff52c03

File tree

8 files changed

+283
-62
lines changed

8 files changed

+283
-62
lines changed

.flake8

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
#########################
2+
# Flake8 Configuration #
3+
# (.flake8) #
4+
#########################
5+
[flake8]
6+
ignore =
7+
# asserts are ok when testing.
8+
S101
9+
# pickle
10+
S301
11+
# pickle
12+
S403
13+
S404
14+
S603
15+
# Line break before binary operator (flake8 is wrong)
16+
W503
17+
# Ignore the spaces black puts before columns.
18+
E203
19+
# allow path extensions for testing.
20+
E402
21+
DAR101
22+
DAR201
23+
# flake and pylance disagree on linebreaks in strings.
24+
N400
25+
N806
26+
exclude =
27+
.tox,
28+
.git,
29+
__pycache__,
30+
docs/source/conf.py,
31+
build,
32+
dist,
33+
tests/fixtures/*,
34+
*.pyc,
35+
*.bib,
36+
*.egg-info,
37+
.cache,
38+
.eggs,
39+
data
40+
max-line-length = 120
41+
max-complexity = 20
42+
import-order-style = pycharm
43+
application-import-names =
44+
seleqt
45+
tests

.github/workflows/tests.yml

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
name: Tests
2+
on: [ push, pull_request ]
3+
jobs:
4+
tests:
5+
name: Tests
6+
runs-on: ubuntu-latest
7+
steps:
8+
- uses: actions/checkout@v4
9+
- name: Install uv
10+
uses: astral-sh/setup-uv@v3
11+
- name: Set up Python 3.8.20
12+
run: uv python install 3.8.20
13+
- name: Install dependencies (CPU only)
14+
run: uv sync --extra cpu --no-dev
15+
- name: Test with pytest
16+
run: uv run pytest
17+
18+
lint:
19+
name: Lint
20+
runs-on: ubuntu-latest
21+
strategy:
22+
matrix:
23+
python-version: ["3.8.20"]
24+
steps:
25+
- uses: actions/checkout@v4
26+
- name: Install uv
27+
uses: astral-sh/setup-uv@v3
28+
- name: Set up Python ${{ matrix.python-version }}
29+
run: uv python install ${{ matrix.python-version }}
30+
- name: Install dependencies
31+
run: uv tool install flake8
32+
- name: Run flake8
33+
run: uv tool run flake8 src/ tests/

.pre-commit-config.yaml

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
repos:
2+
- repo: local
3+
hooks:
4+
- id: isort
5+
name: isort
6+
entry: uv run isort
7+
language: system
8+
args: ["--profile", "black"]
9+
files: ^(src/|tests/)
10+
types: [python]
11+
12+
- id: black
13+
name: black
14+
entry: uv run black
15+
language: system
16+
args: ["--target-version", "py312"]
17+
files: ^(src/|tests/)
18+
types: [python]
19+
20+
- id: flake8
21+
name: flake8
22+
entry: uv run flake8
23+
language: system
24+
files: ^(src/|tests/)
25+
types: [python]
26+
27+
- id: pytest
28+
name: pytest
29+
entry: uv run pytest
30+
language: system
31+
files: ^(tests/)$
32+
pass_filenames: false
33+
always_run: true

Makefile

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
.PHONY: format check-format test lint check-all typing all
2+
3+
# Format code
4+
format:
5+
uv run isort tests/ src/
6+
uv run black tests/ src/
7+
8+
# Check formatting (without making changes)
9+
check-format:
10+
uv run isort --check-only --diff tests/ src/
11+
uv run black --check --diff tests/ src/
12+
13+
lint:
14+
uv run flake8 tests/ src/
15+
16+
test:
17+
uv run pytest
18+
19+
# Run all pre-commit checks manually
20+
check-all:
21+
uv run pre-commit run --all-files
22+
23+
all: format test lint
24+
@echo "All checks passed!"

src/cara/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
"""src package init."""
1+
"""src package init."""

src/cara/cara.py

Lines changed: 91 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,42 @@
1-
from copy import deepcopy
2-
from typing import Dict, Any
1+
"""Implement CaRA."""
32

3+
from typing import Any, Dict
4+
5+
import tensorly as tl
6+
import timm
47
import torch as th
58
import torch.nn as nn
6-
import timm
7-
import tensorly as tl
9+
810
tl.set_backend("pytorch")
911

12+
global_model: th.nn.Module
1013

11-
def cp_attn(self, x):
14+
15+
def cp_attn(self, x: th.Tensor) -> th.Tensor:
16+
"""Attention with CP parameters.
17+
18+
Args:
19+
x (th.Tensor): Input tensor.
20+
21+
Returns:
22+
th.Tensor: CaRA attention output.
23+
"""
1224
B, N, C = x.shape
1325
qkv = self.qkv(x)
14-
f1 = global_model.CP_A1[self.attn_idx:self.attn_idx+3, :]
15-
tensor_attn = tl.cp_to_tensor((global_model.CP_R1, (f1, global_model.CP_A2, global_model.CP_A3, global_model.CP_A4)))
26+
f1 = global_model.CP_A1[self.attn_idx : self.attn_idx + 3, :]
27+
tensor_attn = tl.cp_to_tensor(
28+
(
29+
global_model.CP_R1,
30+
(f1, global_model.CP_A2, global_model.CP_A3, global_model.CP_A4),
31+
)
32+
)
1633
K, E, H, D = tensor_attn.shape
17-
tensor_attn = tensor_attn.reshape((K, E, H*D))
34+
tensor_attn = tensor_attn.reshape((K, E, H * D))
1835
qkv_delta = th.einsum("bnd, kde->kbne", x, self.dp(tensor_attn))
19-
qkv_delta = qkv_delta.reshape(3, B, N, self.num_heads, C//self.num_heads).permute(
20-
0, 1, 3, 2, 4
21-
)
22-
qkv = qkv.reshape(B, N, 3, self.num_heads, C//self.num_heads).permute(
36+
qkv_delta = qkv_delta.reshape(
37+
3, B, N, self.num_heads, C // self.num_heads
38+
).permute(0, 1, 3, 2, 4)
39+
qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(
2340
2, 0, 3, 1, 4
2441
)
2542
qkv += qkv_delta * self.s
@@ -28,56 +45,83 @@ def cp_attn(self, x):
2845
attn = attn.softmax(dim=-1)
2946
attn = self.attn_drop(attn)
3047

31-
x = (attn@v).transpose(1, 2).reshape(B, N, C)
48+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
3249

3350
proj = self.proj(x)
34-
p1 = global_model.CP_P1[self.idx:self.idx+1, :]
35-
tensor_proj = tl.cp_to_tensor((global_model.CP_R2, (p1, global_model.CP_P2, global_model.CP_P3)))
51+
p1 = global_model.CP_P1[self.idx : self.idx + 1, :]
52+
tensor_proj = tl.cp_to_tensor(
53+
(global_model.CP_R2, (p1, global_model.CP_P2, global_model.CP_P3))
54+
)
3655
AA, AB, AC = tensor_proj.shape
37-
tensor_proj = tensor_proj.reshape((AA*AB, AC))
38-
proj_delta = x@self.dp(tensor_proj.T) + global_model.CP_bias1
56+
tensor_proj = tensor_proj.reshape((AA * AB, AC))
57+
proj_delta = x @ self.dp(tensor_proj.T) + global_model.CP_bias1
3958
proj += proj_delta * self.s
4059
x = self.proj_drop(proj)
4160
return x
4261

4362

44-
def cp_mlp(self, x: th.Tensor) -> th.Tensor:
    """MLP forward pass augmented with CP-factorised delta weights.

    Reconstructs low-rank delta weight matrices for the up- and
    down-projections from shared CP factors on ``global_model`` and adds
    the scaled deltas to the frozen MLP outputs.

    Args:
        x (th.Tensor): Input tensor.

    Returns:
        th.Tensor: Mlp projected output.
    """
    # Four rows of the shared factor matrix per projection:
    # rows [idx, idx+4) parametrise the up-projection,
    # rows [idx+4, idx+8) the down-projection.
    rows_up = global_model.CP_P1[self.idx : self.idx + 4, :]
    rows_down = global_model.CP_P1[self.idx + 4 : self.idx + 8, :]

    hidden = self.fc1(x)
    cp_up = tl.cp_to_tensor(
        (global_model.CP_R2, (rows_up, global_model.CP_P2, global_model.CP_P3))
    )
    d0, d1, d2 = cp_up.shape
    cp_up = cp_up.reshape((d0 * d1, d2))
    # Transpose so the delta maps input dim -> hidden dim, matching fc1.
    delta_up = x @ self.dp(cp_up.T) + global_model.CP_bias2
    hidden = hidden + delta_up * self.s

    x = self.act(hidden)
    x = self.drop(x)

    out = self.fc2(x)
    cp_down = tl.cp_to_tensor(
        (global_model.CP_R2, (rows_down, global_model.CP_P2, global_model.CP_P3))
    )
    # Same factor shapes as the up-projection, so the flattened shape is reused;
    # no transpose here: the delta maps hidden dim -> output dim, matching fc2.
    cp_down = cp_down.reshape((d0 * d1, d2))
    delta_down = x @ self.dp(cp_down) + global_model.CP_bias3
    out = out + delta_down * self.s
    x = self.drop(out)
    return x
6596

6697

67-
def set_cara(model: nn.Module, rank: int, scale: float, l_mu: float, l_std: float):
68-
if type(model) == timm.models.vision_transformer.VisionTransformer:
98+
def set_cara(
99+
model: nn.Module, rank: int, scale: float, l_mu: float, l_std: float
100+
) -> None:
101+
"""Cara setup.
102+
103+
Args:
104+
model (nn.Module): ViT model.
105+
rank (int): FT Rank.
106+
scale (float): FT scale.
107+
l_mu (float): Init lambda_mu.
108+
l_std (float): Init lambda_std.
109+
"""
110+
if type(model) is timm.models.vision_transformer.VisionTransformer:
69111
# Declare CaRA parameters
70112
model.CP_A1 = nn.Parameter(th.empty([36, rank]), requires_grad=True)
71113
model.CP_A2 = nn.Parameter(th.empty([768, rank]), requires_grad=True)
72114
model.CP_A3 = nn.Parameter(th.empty([12, rank]), requires_grad=True)
73-
model.CP_A4 = nn.Parameter(th.empty([768//12, rank]), requires_grad=True)
115+
model.CP_A4 = nn.Parameter(
116+
th.empty([768 // 12, rank]), requires_grad=True
117+
)
74118
model.CP_P1 = nn.Parameter(th.empty([108, rank]), requires_grad=True)
75119
model.CP_P2 = nn.Parameter(th.empty([768, rank]), requires_grad=True)
76120
model.CP_P3 = nn.Parameter(th.empty([768, rank]), requires_grad=True)
77121
model.CP_R1 = nn.Parameter(th.empty([rank]), requires_grad=True)
78122
model.CP_R2 = nn.Parameter(th.empty([rank]), requires_grad=True)
79123
model.CP_bias1 = nn.Parameter(th.empty([768]), requires_grad=True)
80-
model.CP_bias2 = nn.Parameter(th.empty([768*4]), requires_grad=True)
124+
model.CP_bias2 = nn.Parameter(th.empty([768 * 4]), requires_grad=True)
81125
model.CP_bias3 = nn.Parameter(th.empty([768]), requires_grad=True)
82126
# Initialise CaRA parameters
83127
nn.init.xavier_normal_(model.CP_A1)
@@ -100,7 +144,7 @@ def set_cara(model: nn.Module, rank: int, scale: float, l_mu: float, l_std: floa
100144
model.idx = 0
101145
model.attn_idx = 0
102146
for child in model.children():
103-
if type(child) == timm.models.vision_transformer.Attention:
147+
if type(child) is timm.models.vision_transformer.Attention:
104148
child.dp = nn.Dropout(0.1)
105149
child.s = scale
106150
child.dim = rank
@@ -109,28 +153,36 @@ def set_cara(model: nn.Module, rank: int, scale: float, l_mu: float, l_std: floa
109153
global_model.idx += 1
110154
global_model.attn_idx += 3
111155
bound_method = cp_attn.__get__(child, child.__class__)
112-
setattr(child, "forward", bound_method)
113-
elif type(child) == timm.models.layers.mlp.Mlp:
156+
setattr(child, "forward", bound_method) # noqa: B010
157+
elif type(child) is timm.models.layers.mlp.Mlp:
114158
child.dp = nn.Dropout(0.1)
115159
child.s = scale
116160
child.dim = rank
117161
child.idx = global_model.idx
118162
global_model.idx += 8
119163
bound_method = cp_mlp.__get__(child, child.__class__)
120-
setattr(child, "forward", bound_method)
164+
setattr(child, "forward", bound_method) # noqa: B010
121165
elif len(list(child.children())) != 0:
122166
set_cara(child, rank, scale, l_mu, l_std)
123-
124167

125-
def cara(config: Dict[str, Any]) -> th.nn.Module:
    """Set CaRA for the given configuration.

    Stores the configured model in the module-level ``global_model``
    (read by the patched attention/MLP forwards) and attaches the CaRA
    parameters to it.

    Args:
        config (Dict[str, Any]): Dictionary containing CaRA configuration;
            expects the keys "model", "rank", "scale", "l_mu" and "l_std".

    Returns:
        th.nn.Module: CaRA model.
    """
    global global_model
    global_model = config["model"]
    set_cara(
        global_model,
        config["rank"],
        config["scale"],
        config["l_mu"],
        config["l_std"],
    )
    return global_model

tests/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
"""Test package."""
1+
"""Test package."""

0 commit comments

Comments
 (0)