Skip to content

Commit 960d971

Browse files
committed
trying to have the basis matrix as the identity matrix
1 parent 1cbcc13 commit 960d971

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

pysensors/basis/_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@ def matrix_representation(self, n_basis_modes=None, copy=False):
4343
n_basis_modes = self._validate_input(n_basis_modes)
4444

4545
if copy:
46-
return self.basis_matrix_[:, :n_basis_modes].copy()
46+
return self.basis_matrix_[:, :n_basis_modes].copy()#self.original_data @
4747
else:
48-
return self.basis_matrix_[:, :n_basis_modes]
48+
return self.basis_matrix_[:, :n_basis_modes]#self.original_data @
4949

5050
def _validate_input(self, n_basis_modes):
5151
"""

pysensors/basis/_identity.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from warnings import warn
77

88
from numpy import identity
9+
import numpy as np
910
from sklearn.base import BaseEstimator
1011
from sklearn.utils import check_array
1112

@@ -52,7 +53,8 @@ def fit(self, X):
5253
-------
5354
self : instance
5455
"""
55-
56+
# Store original data
57+
self.original_data = X
5658
# Note that we take a transpose here, so columns correspond to examples
5759
if self.n_basis_modes is None:
5860
self.basis_matrix_ = check_array(X).T.copy()
@@ -65,10 +67,10 @@ def fit(self, X):
6567
)
6668
)
6769

68-
self.basis_matrix_ = check_array(X)[: self.n_basis_modes, :].T.copy()
70+
self.basis_matrix_ = np.eye(X.shape[1])[:,:self.n_basis_modes] #check_array(X)[: self.n_basis_modes, :].T.copy()
6971

70-
if self.n_basis_modes < X.shape[0]:
71-
warn(f"Only the first {self.n_basis_modes} examples were retained.")
72+
# if self.n_basis_modes < X.shape[0]:
73+
# warn(f"Only the first {self.n_basis_modes} examples were retained.")
7274
return self
7375

7476
def matrix_inverse(self, n_basis_modes=None):

0 commit comments

Comments
 (0)