Skip to content

Commit 72f9e12

Browse files
authored
1 hybrid extraction framework (#2)
* first version * fix mypy & minor adjust * setup basic ci
1 parent 6fc9b6f commit 72f9e12

File tree

11 files changed

+335
-35
lines changed

11 files changed

+335
-35
lines changed

.github/actions/setup/action.yml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Composite GitHub Action that installs the uv package manager on the runner.
# Used by the code-quality workflow before any `uv ...` / `uvx ...` command.
runs:
  using: composite
  steps:
    - name: "Install UV"
      shell: bash
      # Official astral.sh installer script; installs uv into the runner's PATH.
      run: |
        curl -LsSf https://astral.sh/uv/install.sh | sh

.github/workflows/code-quality.yml

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# CI pipeline: lock-file check -> lint / format / type-check (in parallel) -> build.
name: Python Code Quality

on:
  pull_request:
  push:
    branches:
      - main

jobs:
  # Fail fast if uv.lock is out of sync with pyproject.toml.
  lock_file:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: ./.github/actions/setup
      - run: uv lock --locked

  linting:
    runs-on: ubuntu-latest
    needs: [lock_file]
    steps:
      - uses: actions/checkout@v4
      - uses: ./.github/actions/setup
      - run: uvx ruff check src/

  formatting:
    runs-on: ubuntu-latest
    needs: [lock_file]
    steps:
      - uses: actions/checkout@v4
      - uses: ./.github/actions/setup
      - run: uvx ruff format --check src/

  type_consistency:
    runs-on: ubuntu-latest
    needs: [lock_file]
    steps:
      - uses: actions/checkout@v4
      - uses: ./.github/actions/setup
      - run: uv run mypy src/

  # Build only after all quality gates pass.
  build:
    # Consistency fix: scalar form, matching every other job (was `[ubuntu-latest]`).
    runs-on: ubuntu-latest
    needs: [linting, formatting, type_consistency]
    steps:
      - uses: actions/checkout@v4
      - uses: ./.github/actions/setup
      - run: uv build

Makefile

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,8 @@ format:
22
uv run ruff format
33
uv run ruff check --fix src/
44

5+
check_format:
6+
uv run ruff check src
7+
58
mypy:
6-
uv run mypy src/ --namespace-packages --explicit-package-bases
9+
uv run mypy src/

src/autoencoder.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
from typing import Any
2+
3+
import torch
4+
from torch import nn
5+
6+
layers = [1024, 512, 256, 128, 256, 512, 1024]
7+
8+
9+
class Autoencoder(nn.Module):
10+
def __init__(self) -> None:
11+
super().__init__()
12+
13+
encoder_layers: list[nn.Module] = []
14+
for i in range(len(layers) // 2):
15+
encoder_layers.append(nn.Linear(layers[i], layers[i + 1]))
16+
encoder_layers.append(nn.Sigmoid())
17+
self.encoder = nn.Sequential(*encoder_layers)
18+
19+
decoder_layers: list[nn.Module] = []
20+
for i in range(len(layers) // 2, len(layers) - 1):
21+
decoder_layers.append(nn.Linear(layers[i], layers[i + 1]))
22+
decoder_layers.append(nn.Sigmoid())
23+
self.decoder = nn.Sequential(*decoder_layers)
24+
25+
def forward(self, x: torch.Tensor) -> torch.Tensor | Any:
26+
x = self.encoder(x)
27+
return self.decoder(x)
28+
29+
def print_weight_shapes(self) -> None:
30+
print("Encoder layers:")
31+
for layer in self.encoder:
32+
if isinstance(layer, nn.Linear):
33+
print(layer.weight.shape)
34+
35+
print("Decoder layers:")
36+
for layer in self.decoder:
37+
if isinstance(layer, nn.Linear):
38+
print(layer.weight.shape)

src/classic_extractor.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
import torch
2+
3+
4+
class ClassicExtractor:
    """Placeholder for a classic (non-SVD) weight extractor.

    The current implementation is a pass-through: it hands back exactly the
    weight list it receives.
    """

    def __init__(self) -> None: ...

    def extract(self, x: list[torch.Tensor]) -> list[torch.Tensor]:
        """Return the given weight tensors unchanged (identity extraction)."""
        return x

src/extraction_framework.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
from collections.abc import Callable
2+
from typing import Any
3+
4+
import torch
5+
from torch.linalg import svd
6+
7+
from src.autoencoder import layers
8+
from src.classic_extractor import ClassicExtractor
9+
from src.inverse_activation import stable_logit
10+
from src.svd_classifier import SVDClassifier
11+
from src.utils import restore_weights_from_svd
12+
13+
14+
class ExtractionFramework:
    """Hybrid weight-extraction pipeline.

    Alternates SVD-based extraction (for the decoder-side layers of the
    Q-matrix) with a classic extractor (for the remaining layers), stitching
    both results into one ordered weight list.
    """

    def __init__(
        self,
        decomposition_clf: SVDClassifier,
        classic_extractor: ClassicExtractor,
        q_matrix: torch.Tensor,
    ) -> None:
        """Store collaborators and run the first decomposition/prediction.

        Args:
            decomposition_clf: Predicts the next hidden-layer width from
                singular values.
            classic_extractor: Extracts the layers not recoverable via SVD.
            q_matrix: Query/response matrix the weights are recovered from.
        """
        self.decomposition_clf = decomposition_clf
        self.classic_extractor = classic_extractor
        self.q_matrix = q_matrix * 0.1  # Scale initial Q-matrix

        # Perform initial decomposition and prediction
        self.u: torch.Tensor
        self.s: torch.Tensor
        self.v: torch.Tensor
        self.u, self.s, self.v = self.decompose()
        # self.h: current hidden-layer width estimate; drives the extract() loop.
        self.h = self.decomposition_clf.predict(self.s)

        # Initialize placeholders for extracted weights
        self.w_svd: list[torch.Tensor] = []
        self.w_hat: list[torch.Tensor] = []

    def decompose(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor] | Any:
        """Decomposes the current Q-matrix using SVD, returning (U, S, V)."""
        return svd(self.q_matrix)

    def inverse_activation(
        self, x: torch.Tensor, fun: Callable[..., torch.Tensor] = stable_logit
    ) -> torch.Tensor:
        """Inverse activation function. Inverse sigmoid (logit) by default.

        NOTE(review): mypy reports "Too many arguments [call-arg]" here
        (src/extraction_framework.py:42) because `fun` is typed with a bare
        Callable[..., ...]; the default `stable_logit` does take three args.
        Also note the clamp constants are (re)assigned on every call, and
        other methods (`_validate_tensor`) read `self._max_abs` — it only
        exists after this method has run at least once.
        """
        self._clamp_eps = 1e-7
        self._max_abs = 20.0
        return fun(x, self._clamp_eps, self._max_abs)

    def extract(self) -> list[torch.Tensor]:
        """Extracts the weights from the neural network.

        Returns:
            Classic-extracted layers followed by the SVD-extracted layers
            (SVD results are reversed so the list runs input-to-output).
        """
        h = 0  # Initialize hidden layer dimension tracker
        w_hat_svd = []  # List to store extracted weights via SVD

        # Loop invariant: self.h is the width being extracted this round;
        # h is the classifier's prediction for the NEXT round. The loop ends
        # when the predicted width stops shrinking below the current one.
        while h < self.h:
            self.h = h if h != 0 else self.h  # First pass keeps the __init__ prediction
            print(f"Hidden size: {self.h}")

            self.u, self.s, self.v = self.decompose()
            w_hat_i = restore_weights_from_svd(self.u, self.s, self.h)
            print(f"W_hat_{self.h} shape: {w_hat_i.shape}")
            w_hat_svd.append(w_hat_i)

            # Update Q-matrix for next iteration
            print(f"Q_matrix shape before update: {self.q_matrix.shape}")

            # Ensure data types match for matrix multiplication
            w_hat_i = w_hat_i.to(self.q_matrix.dtype)
            try:
                # Project Q through the recovered layer, then undo the sigmoid.
                self.q_matrix = torch.matmul(w_hat_i.transpose(0, 1), self.q_matrix)
                self.q_matrix = self.inverse_activation(self.q_matrix)

                # Add stability check
                if torch.isnan(self.q_matrix).any() or torch.isinf(self.q_matrix).any():
                    self.q_matrix = torch.nan_to_num(
                        self.q_matrix,
                        nan=0.0,
                        posinf=self._max_abs,
                        neginf=-self._max_abs,
                    )

            except RuntimeError as e:
                print(f"Error during Q-matrix update: {e}")
                # Add recovery mechanism: clamp and stop extracting via SVD.
                self.q_matrix = torch.clamp(
                    self.q_matrix, -self._max_abs, self._max_abs
                )
                break

            print(f"Q_matrix shape after update: {self.q_matrix.shape}")
            # Predict next hidden layer dimension
            h = self.decomposition_clf.predict(self.s)

        # SVD extraction walked the network backwards; restore forward order.
        w_hat_svd.reverse()

        if not hasattr(self, "w_remaining"):
            self.initialize_remaining_layers()

        extracted_layers = self.classic_extractor.extract(self.w_remaining)

        print(f"Size of extracted layers: {len(extracted_layers)}")
        for _ in range(len(extracted_layers)):
            print(f"Extracted layer {_ + 1} shape: {extracted_layers[_].shape}")

        self.w_hat = extracted_layers + w_hat_svd

        return self.w_hat

    def initialize_remaining_layers(self) -> None:
        """Initializes remaining layers with random weights.

        Builds (out, in) shaped random tensors for the first half of the
        global `layers` list — the layers the SVD loop does not recover.
        """
        try:
            remaining_layers_sizes = [
                (layers[i + 1], layers[i]) for i in range((len(layers) - 1) // 2)
            ]
            self.w_remaining = [
                torch.randn(size).float() for size in remaining_layers_sizes
            ]
        except Exception as e:  # noqa: BLE001
            print(f"Error initializing remaining layers: {e}")

    def _validate_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
        """Ensure numerical validity of matrix.

        NOTE(review): reads self._max_abs, which is only defined after
        inverse_activation() has been called — confirm call order.
        """
        tensor = torch.nan_to_num(tensor, nan=0.0)
        return torch.clamp(tensor, -self._max_abs, self._max_abs)

    def _safe_matmul(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """Stable matrix multiplication with pre-check."""
        a = self._validate_tensor(a)
        b = self._validate_tensor(b)
        return torch.matmul(a, b)

src/inverse_activation.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import torch
2+
3+
4+
def stable_logit(x: torch.Tensor, eps: float, max_abs: float) -> torch.Tensor:
    """Numerically stable inverse sigmoid (logit) with output clamping.

    Args:
        x: Probabilities, expected in (0, 1); values outside are clamped in.
        eps: Margin keeping inputs away from exactly 0 and 1 before the logs.
        max_abs: Maximum absolute value allowed in the returned logits.

    Returns:
        ``log(x) - log1p(-x)`` elementwise, clamped to ``[-max_abs, max_abs]``.
    """
    # Clamp probabilities to a safe open interval so the logs stay finite.
    x_clamped = torch.clamp(x, eps, 1 - eps)

    # Compute logit with separate log operations — more stable than log(x/(1-x)).
    logit = torch.log(x_clamped) - torch.log1p(-x_clamped)

    # BUG FIX: the original clamped to [max_abs, max_abs], which collapses
    # every output to the constant max_abs. The intended range is symmetric.
    return torch.clamp(logit, -max_abs, max_abs)

src/main.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
from src.autoencoder import Autoencoder
2+
from src.classic_extractor import ClassicExtractor
3+
from src.extraction_framework import ExtractionFramework
4+
from src.svd_classifier import SVDClassifier
5+
from src.utils import generate_q_matrix, scale_tensor
6+
7+
if __name__ == "__main__":
8+
clf = SVDClassifier()
9+
extractor = ClassicExtractor()
10+
model = Autoencoder()
11+
12+
Q = generate_q_matrix(model, 1024, 2000)
13+
Q = scale_tensor(Q)
14+
15+
framework = ExtractionFramework(clf, extractor, Q)
16+
17+
W = framework.extract()
18+
# Hidden size: 512
19+
# W_hat_512 shape: torch.Size([1024, 512])
20+
# Q_matrix shape before update: torch.Size([1024, 2000])
21+
# Q_matrix shape after update: torch.Size([512, 2000])
22+
# Hidden size: 256
23+
# W_hat_256 shape: torch.Size([512, 256])
24+
# Q_matrix shape before update: torch.Size([512, 2000])
25+
# Q_matrix shape after update: torch.Size([256, 2000])
26+
# Hidden size: 128
27+
# W_hat_128 shape: torch.Size([256, 128])
28+
# Q_matrix shape before update: torch.Size([256, 2000])
29+
# Q_matrix shape after update: torch.Size([128, 2000])
30+
# Size of extracted layers: 3
31+
# Extracted layer 1 shape: torch.Size([512, 1024])
32+
# Extracted layer 2 shape: torch.Size([256, 512])
33+
# Extracted layer 3 shape: torch.Size([128, 256])
34+
35+
for w in W:
36+
print(w.size())
37+
# torch.Size([512, 1024])
38+
# torch.Size([256, 512])
39+
# torch.Size([128, 256])
40+
# torch.Size([256, 128])
41+
# torch.Size([512, 256])
42+
# torch.Size([1024, 512])
43+
44+
model.print_weight_shapes()
45+
# Encoder layers:
46+
# torch.Size([512, 1024])
47+
# torch.Size([256, 512])
48+
# torch.Size([128, 256])
49+
# Decoder layers:
50+
# torch.Size([256, 128])
51+
# torch.Size([512, 256])
52+
# torch.Size([1024, 512])

src/softmax_bottleneck_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,3 +63,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor | Any:
6363
x = self.fc1(x)
6464
x = self.softmax(x)
6565
return self.fc2(x)
66+
67+
68+
# TODO: use only torch tensors # noqa: FIX002

src/svd_classifier.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
import torch
2+
3+
from src.autoencoder import layers
4+
5+
6+
class SVDClassifier:
    """Stub classifier that replays the ground-truth layer widths in order.

    Instead of inferring the hidden-layer width from the singular values,
    it steps through the global ``layers`` list on every call.
    """

    def __init__(self) -> None:
        # Index into `layers`; advanced by one on every prediction.
        self.state = 1

    def predict(self, sigma: torch.Tensor) -> int:  # noqa: ARG002
        """Return the next hidden-layer width, ignoring `sigma` entirely."""
        current = self.state
        self.state = current + 1
        return layers[current]  # always returns the ground-truth value

0 commit comments

Comments
 (0)