11import numpy as np
2- import pandas as pd
32import pytest
43from sklearn .datasets import make_regression
4+ from tests .compare_predictions import _base_check_imputation
55
66from fknni .faiss .faiss import FaissImputer
77
88
@pytest.fixture
def simple_test_df(rng):
    """Provide a small integer dataset and a copy with five values knocked out.

    Args:
        rng: Random generator fixture; it must offer ``integers`` and ``shuffle``.

    Returns:
        A ``(data, data_missing)`` tuple of NumPy arrays. ``data`` holds the
        complete 10x5 table; ``data_missing`` is the same table with five
        randomly chosen cells replaced by NaN.
    """
    data = pd.DataFrame(rng.integers(0, 100, size=(10, 5)), columns=list("ABCDE"))
    # Cast up front: writing NaN into an integer column relies on an implicit
    # upcast that pandas has deprecated and will turn into an error.
    data_missing = data.astype(float)
    indices = [(i, j) for i in range(data.shape[0]) for j in range(data.shape[1])]
    rng.shuffle(indices)
    for i, j in indices[:5]:
        data_missing.iat[i, j] = np.nan
    return data.to_numpy(), data_missing.to_numpy()
18-
19-
209@pytest .fixture
2110def regression_dataset (rng ):
2211 X , y = make_regression (n_samples = 100 , n_features = 20 , random_state = 42 )
@@ -28,36 +17,6 @@ def regression_dataset(rng):
2817 return X , X_missing , y
2918
3019
def _base_check_imputation(
    data_original: np.ndarray,
    data_imputed: np.ndarray,
):
    """Run the baseline sanity checks shared by the imputation tests.

    The checks are, in order:
    - both arrays have the same shape,
    - the imputed array contains no NaN,
    - every entry that was not NaN in the original is unchanged after imputation.

    Args:
        data_original: Dataset before imputation (may contain NaN).
        data_imputed: Dataset after imputation.

    Raises:
        AssertionError: If any of the checks fail.
    """
    shapes_match = data_original.shape == data_imputed.shape
    if not shapes_match:
        raise AssertionError("The shapes of the two datasets do not match")

    # No NaN may survive imputation.
    leftover_nan = np.isnan(data_imputed)
    if leftover_nan.any():
        raise AssertionError("NaN found in imputed columns of layer_after.")

    # Values that were present before imputation must come through untouched.
    observed = ~np.isnan(data_original)
    untouched = _are_ndarrays_equal(data_original[observed], data_imputed[observed])
    if not untouched:
        raise AssertionError("Non-NaN values in imputed columns were modified.")
59-
60-
6120def test_median_imputation (simple_test_df ):
6221 """Tests if median imputation successfully fills all NaN values"""
6322 data , data_missing = simple_test_df
@@ -222,18 +181,3 @@ def test_invalid_temporal_mode():
222181 """Tests if imputer raises error for invalid temporal_mode"""
223182 with pytest .raises (ValueError ):
224183 FaissImputer (temporal_mode = "invalid" )
225-
226-
def _are_ndarrays_equal(arr1: np.ndarray, arr2: np.ndarray) -> np.bool_:
    """Tell whether two arrays match element by element.

    Note: positions where both arrays hold NaN count as equal.

    Args:
        arr1: First array to compare
        arr2: Second array to compare

    Returns:
        True if every pair of corresponding elements is equal (or both NaN)
    """
    same_value = np.equal(arr1, arr2, dtype=object)
    # A value differs from itself only when it is NaN.
    both_nan = (arr1 != arr1) & (arr2 != arr2)
    return np.all(same_value | both_nan)
0 commit comments