|
| 1 | +from collections.abc import Callable |
| 2 | +from typing import Any |
| 3 | + |
| 4 | +import torch |
| 5 | +from torch.linalg import svd |
| 6 | + |
| 7 | +from src.autoencoder import layers |
| 8 | +from src.classic_extractor import ClassicExtractor |
| 9 | +from src.inverse_activation import stable_logit |
| 10 | +from src.svd_classifier import SVDClassifier |
| 11 | +from src.utils import restore_weights_from_svd |
| 12 | + |
| 13 | + |
class ExtractionFramework:
    """Extracts neural-network weights by iterated SVD of a Q-matrix.

    Layers recoverable through SVD are peeled off one at a time in
    ``extract``; whatever remains is handed to the classic extractor.
    """

    # Numerical-stability constants. Class-level (not set lazily inside
    # inverse_activation) so that extract()'s RuntimeError recovery path,
    # _validate_tensor and _safe_matmul can never hit an AttributeError
    # when they run before inverse_activation has been called.
    _clamp_eps: float = 1e-7
    _max_abs: float = 20.0

    def __init__(
        self,
        decomposition_clf: SVDClassifier,
        classic_extractor: ClassicExtractor,
        q_matrix: torch.Tensor,
    ) -> None:
        """Store collaborators, scale the Q-matrix, and run the first SVD.

        Args:
            decomposition_clf: Predicts a hidden-layer size from singular values.
            classic_extractor: Fallback extractor for the remaining layers.
            q_matrix: Initial query matrix; scaled by 0.1 before use.
        """
        self.decomposition_clf = decomposition_clf
        self.classic_extractor = classic_extractor
        self.q_matrix = q_matrix * 0.1  # Scale initial Q-matrix

        # Perform initial decomposition and hidden-size prediction.
        self.u: torch.Tensor
        self.s: torch.Tensor
        self.v: torch.Tensor
        self.u, self.s, self.v = self.decompose()
        self.h = self.decomposition_clf.predict(self.s)

        # Placeholders for extracted weights.
        self.w_svd: list[torch.Tensor] = []
        self.w_hat: list[torch.Tensor] = []

    def decompose(self) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Decompose the current Q-matrix via SVD, returning (U, S, Vh)."""
        return svd(self.q_matrix)

    def inverse_activation(
        self, x: torch.Tensor, fun: Callable[..., torch.Tensor] = stable_logit
    ) -> torch.Tensor:
        """Apply the inverse activation (inverse sigmoid by default).

        ``fun`` is first called with the clamping parameters; if it does not
        accept them, it is called with the tensor alone.
        """
        try:
            return fun(x, self._clamp_eps, self._max_abs)
        except TypeError:
            # NOTE(review): mypy reported "Too many arguments [call-arg]"
            # for the three-argument call to stable_logit, so fall back to
            # the single-argument form — confirm stable_logit's signature.
            return fun(x)

    def extract(self) -> list[torch.Tensor]:
        """Extract the weights from the neural network.

        Repeatedly decomposes the Q-matrix, restores a weight matrix from
        the SVD factors, and folds it back into the Q-matrix through the
        inverse activation, until the predicted hidden size stops growing
        past the tracker ``h``. Remaining layers come from the classic
        extractor.

        Returns:
            Classic-extracted layers followed by the SVD-recovered layers
            (SVD list reversed into forward order).
        """
        h = 0  # Hidden-size tracker; 0 only before the first iteration
        w_hat_svd: list[torch.Tensor] = []  # Weights recovered via SVD

        while h < self.h:
            # Adopt the newly predicted hidden size (skip on first pass).
            self.h = h if h != 0 else self.h
            print(f"Hidden size: {self.h}")

            self.u, self.s, self.v = self.decompose()
            w_hat_i = restore_weights_from_svd(self.u, self.s, self.h)
            print(f"W_hat_{self.h} shape: {w_hat_i.shape}")
            w_hat_svd.append(w_hat_i)

            # Update Q-matrix for next iteration.
            print(f"Q_matrix shape before update: {self.q_matrix.shape}")

            # Ensure data types match for matrix multiplication.
            w_hat_i = w_hat_i.to(self.q_matrix.dtype)
            try:
                self.q_matrix = torch.matmul(w_hat_i.transpose(0, 1), self.q_matrix)
                self.q_matrix = self.inverse_activation(self.q_matrix)

                # Stability check: replace NaN/Inf produced by the inverse
                # activation with finite, clamped values.
                if torch.isnan(self.q_matrix).any() or torch.isinf(self.q_matrix).any():
                    self.q_matrix = torch.nan_to_num(
                        self.q_matrix,
                        nan=0.0,
                        posinf=self._max_abs,
                        neginf=-self._max_abs,
                    )

            except RuntimeError as e:
                print(f"Error during Q-matrix update: {e}")
                # Recovery: clamp the Q-matrix and stop peeling layers.
                self.q_matrix = torch.clamp(
                    self.q_matrix, -self._max_abs, self._max_abs
                )
                break

            print(f"Q_matrix shape after update: {self.q_matrix.shape}")
            # Predict next hidden layer dimension.
            h = self.decomposition_clf.predict(self.s)

        # SVD peels layers back-to-front; reverse into forward order.
        w_hat_svd.reverse()

        if not hasattr(self, "w_remaining"):
            self.initialize_remaining_layers()

        extracted_layers = self.classic_extractor.extract(self.w_remaining)

        print(f"Size of extracted layers: {len(extracted_layers)}")
        for idx, layer in enumerate(extracted_layers, start=1):
            print(f"Extracted layer {idx} shape: {layer.shape}")

        self.w_hat = extracted_layers + w_hat_svd

        return self.w_hat

    def initialize_remaining_layers(self) -> None:
        """Initialize the not-yet-extracted layers with random weights.

        Sizes come from the first half of the project-level ``layers``
        spec (consecutive pairs, output-dim first).
        """
        try:
            remaining_layers_sizes = [
                (layers[i + 1], layers[i]) for i in range((len(layers) - 1) // 2)
            ]
            self.w_remaining = [
                torch.randn(size).float() for size in remaining_layers_sizes
            ]
        except Exception as e:  # noqa: BLE001 — best-effort; logged, not raised
            print(f"Error initializing remaining layers: {e}")

    def _validate_tensor(self, tensor: torch.Tensor) -> torch.Tensor:
        """Ensure numerical validity: drop NaNs, clamp to +/- _max_abs."""
        tensor = torch.nan_to_num(tensor, nan=0.0)
        return torch.clamp(tensor, -self._max_abs, self._max_abs)

    def _safe_matmul(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        """Stable matrix multiplication: sanitize both operands first."""
        a = self._validate_tensor(a)
        b = self._validate_tensor(b)
        return torch.matmul(a, b)
0 commit comments