Skip to content

Commit 5171089

Browse files
committed
CLN: update custom basis tests
1 parent 1d13c4f commit 5171089

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

tests/basis/test_basis.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -139,15 +139,17 @@ def custom_instance(sample_basis):
139139

140140
def test_fit_method(custom_instance, sample_basis):
141141
"""Test that fit correctly sets basis_matrix_ and returns self."""
142-
result = custom_instance.fit()
142+
X = np.ones(sample_basis.shape)
143+
result = custom_instance.fit(X)
143144
expected_basis_matrix = sample_basis[:, :3]
144145
np.testing.assert_array_equal(custom_instance.basis_matrix_, expected_basis_matrix)
145146
assert result is custom_instance
146147

147148

148149
def test_matrix_inverse_default(custom_instance):
149150
"""Test matrix_inverse with default n_basis_modes."""
150-
custom_instance.fit()
151+
X = np.random.random((10, 10))
152+
custom_instance.fit(X)
151153
result = custom_instance.matrix_inverse()
152154
expected_result = custom_instance.basis_matrix_.T
153155
np.testing.assert_array_equal(result, expected_result)
@@ -156,7 +158,8 @@ def test_matrix_inverse_default(custom_instance):
156158
@pytest.mark.parametrize("n_modes", [1, 2])
157159
def test_matrix_inverse_with_n_basis_modes(custom_instance, n_modes):
158160
"""Test matrix_inverse with specified n_basis_modes."""
159-
custom_instance.fit()
161+
X = np.random.random((10, 10))
162+
custom_instance.fit(X)
160163
result = custom_instance.matrix_inverse(n_basis_modes=n_modes)
161164
expected_result = custom_instance.basis_matrix_[:, :n_modes].T
162165
np.testing.assert_array_equal(result, expected_result)
@@ -183,7 +186,8 @@ def test_n_basis_modes_setter(custom_instance):
183186

184187
def test_matrix_inverse_calls_validate_input(custom_instance, monkeypatch):
185188
"""Test that matrix_inverse calls _validate_input."""
186-
custom_instance.fit()
189+
X = np.random.random((10, 10))
190+
custom_instance.fit(X)
187191
validation_called = False
188192
test_value = None
189193

0 commit comments

Comments
 (0)