11from __future__ import annotations
22
3- from typing import Literal , Any
4- from lamin_utils import logger
3+ from collections . abc import Sequence
4+ from typing import Any , Literal
55
66import faiss
77import numpy as np
8- from numpy import ndarray , dtype
8+ from lamin_utils import logger
9+ from numpy import dtype
910from sklearn .base import BaseEstimator , TransformerMixin
1011
12+
1113class FaissImputer (BaseEstimator , TransformerMixin ):
1214 """Imputer for completing missing values using Faiss, incorporating weighted averages based on distance."""
1315
@@ -20,6 +22,7 @@ def __init__(
2022 strategy : Literal ["mean" , "median" , "weighted" ] = "mean" ,
2123 index_factory : str = "Flat" ,
2224 min_data_ratio : float = 0.25 ,
25+ temporal_mode : Literal ["flatten" , "per_variable" ] = "flatten" ,
2326 ):
2427 """Initializes FaissImputer with specified parameters that are used for the imputation.
2528
@@ -33,59 +36,93 @@ def __init__(
3336 index_factory: Description of the Faiss index type to build.
3437 min_data_ratio: The minimum (dimension 0) size of the FAISS index relative to the (dimension 0) size of the
3538 dataset that will be used to train FAISS. Defaults to 0.25. See also `fit_transform`.
39+ temporal_mode: How to handle 3D temporal data. 'flatten' treats all (variable, timestep) pairs as
40+ independent features (fast but allows temporal leakage).
41+ 'per_variable' imputes each variable independently across time (slower but respects temporal causality).
3642 """
3743 if n_neighbors < 1 :
3844 raise ValueError ("n_neighbors must be at least 1." )
3945 if strategy not in {"mean" , "median" , "weighted" }:
4046 raise ValueError ("Unknown strategy. Choose one of 'mean', 'median', 'weighted'" )
47+ if temporal_mode not in {"flatten" , "per_variable" }:
48+ raise ValueError ("Unknown temporal_mode. Choose one of 'flatten', 'per_variable'" )
4149
4250 self .missing_values = missing_values
4351 self .n_neighbors = n_neighbors
4452 self .metric = metric
4553 self .strategy = strategy
4654 self .index_factory = index_factory
55+ self .temporal_mode = temporal_mode
4756 self .X_full = None
4857 self .features_nan = None
4958 self .min_data_ratio = min_data_ratio
5059 self .warned_fallback = False
5160 self .warned_unsufficient_neighbors = False
5261 super ().__init__ ()
5362
54- # @override
55- def fit_transform (self , X : np .ndarray , y = None , ** fit_params ) -> ndarray [Any , dtype [Any ]] | None :
63+ def fit_transform ( # noqa: D417
64+ self , X : np .ndarray , y : np .ndarray | None = None , ** fit_params
65+ ) -> np .ndarray [Any , dtype [Any ]] | None :
5666 """Imputes missing values in the data using the fitted Faiss index. This imputation will be performed in place.
57- This imputation will use self.min_data_ratio to check if the index is of sufficient (dimension 0) size to
58- perform a qualitative KNN lookup. If not, it will temporarily exclude enough features to reach this threshold
59- and try again. If an index still can't be built, it will use fallbacks values as defined by self.strategy.
67+
68+ This imputation will use `min_data_ratio` to check if the index is of sufficient (dimension 0) size to perform a qualitative KNN lookup.
69+ If not, it will temporarily exclude enough features to reach this threshold and try again.
70+ If an index still can't be built, it will use fallbacks values as defined by self.strategy.
6071
6172 Args:
62- X: Input data with potential missing values.
73+ X: Input data with potential missing values. Can be 2D (samples × features) or 3D (samples × features × timesteps).
6374 y: Ignored, present for compatibility with sklearn's TransformerMixin.
6475
6576 Returns:
6677 Data with imputed values as a NumPy array of the original data type.
6778 """
79+ original_shape = X .shape
80+
81+ if X .ndim == 3 and self .temporal_mode == "per_variable" :
82+ n_obs , n_vars , n_t = X .shape
83+ result = np .empty_like (X , dtype = np .float64 )
84+ for var_idx in range (n_vars ):
85+ X_slice = X [:, var_idx , :]
86+ result [:, var_idx , :] = self ._impute_2d (X_slice )
87+ return result
88+
89+ if X .ndim == 3 :
90+ n_obs , n_vars , n_t = X .shape
91+ X = X .reshape (n_obs , n_vars * n_t )
92+
93+ result = self ._impute_2d (X )
94+
95+ if len (original_shape ) == 3 :
96+ result = result .reshape (original_shape )
97+
98+ return result
99+
100+ def _impute_2d (self , X : np .ndarray ) -> np .ndarray :
68101 self .X_full = np .asarray (X , dtype = np .float64 ) if not np .issubdtype (X .dtype , np .floating ) else X
69102 if np .isnan (self .X_full ).all (axis = 0 ).any ():
70103 raise ValueError ("Features with only missing values cannot be handled." )
71104
72105 # Prepare fallback values, used to prefill the query vectors nan´s
73106 # or as an imputation fallback if we can't build an index
74107 global_fallbacks_ = (
75- np .nanmean (self .X_full , axis = 0 ) if self .strategy in ["mean" , "weighted" ] else np .nanmedian (self .X_full , axis = 0 )
108+ np .nanmean (self .X_full , axis = 0 )
109+ if self .strategy in ["mean" , "weighted" ]
110+ else np .nanmedian (self .X_full , axis = 0 )
76111 )
77112
78113 # We will need to impute all features having nan´s
79114 feature_indices_to_impute = [i for i in range (self .X_full .shape [1 ]) if np .isnan (self .X_full [:, i ]).any ()]
80115
81116 # Now impute iteratively
82117 while feature_indices_to_impute :
83- feature_indices_being_imputed , training_indices , training_data , index = self ._fit_train_imputer (feature_indices_to_impute )
118+ feature_indices_being_imputed , training_indices , training_data , index = self ._fit_train_imputer (
119+ feature_indices_to_impute
120+ )
84121
85122 # Use fallback data if we can't build an index and iterate again
86123 if index is None :
87124 self ._warn_fallback ()
88- self .X_full [:, feature_indices_being_imputed ] = global_fallbacks_ [feature_indices_being_imputed ]
125+ self .X_full [:, feature_indices_being_imputed ] = global_fallbacks_ [feature_indices_being_imputed ]
89126 continue
90127
91128 # Extract the features from X that was used to train FAISS, and compute the sparseness matrix
@@ -106,7 +143,9 @@ def fit_transform(self, X: np.ndarray, y=None, **fit_params) -> ndarray[Any, dty
106143 # Call FAISS and retrieve data
107144 distances , indices = index .search (sample .reshape (1 , - 1 ), self .n_neighbors )
108145 assert len (indices [0 ]) == self .n_neighbors
109- valid_indices = indices [0 ][indices [0 ] >= 0 ] # Filter out negative indices because they are FAISS error codes
146+ valid_indices = indices [0 ][
147+ indices [0 ] >= 0
148+ ] # Filter out negative indices because they are FAISS error codes
110149
111150 # FAISS couldn't find any neighbor, use fallback values, and go to next row
112151 if len (valid_indices ) == 0 :
@@ -117,8 +156,9 @@ def fit_transform(self, X: np.ndarray, y=None, **fit_params) -> ndarray[Any, dty
117156 # FAISS couldn't find the amount of requested neighbors, warn user and proceed
118157 if len (valid_indices ) < self .n_neighbors :
119158 if not self .warned_unsufficient_neighbors :
120- logger .warning (f"FAISS couldn't find all the requested neighbors. "
121- f"This warning will be displayed only once." )
159+ logger .warning (
160+ "FAISS couldn't find all the requested neighbors. This warning will be displayed only once."
161+ )
122162 self .warned_unsufficient_neighbors = True
123163
124164 # Apply strategy on neighbors data
@@ -139,20 +179,24 @@ def fit_transform(self, X: np.ndarray, y=None, **fit_params) -> ndarray[Any, dty
139179 self .X_full [:, feature_indices_being_imputed ] = x_imputed [:, np .arange (len (feature_indices_being_imputed ))]
140180
141181 # Remove the imputed features from the to-do list
142- feature_indices_to_impute = [feature_indice for feature_indice in feature_indices_to_impute if feature_indice not in feature_indices_being_imputed ]
182+ feature_indices_to_impute = [
183+ feature_indice
184+ for feature_indice in feature_indices_to_impute
185+ if feature_indice not in feature_indices_being_imputed
186+ ]
143187
144188 assert not np .isnan (self .X_full ).any ()
145189 return self .X_full
146190
147- def _fit_train_imputer (self , features_indices : list [int ]) -> (list [int ],
148- list [int ] | None ,
149- np .ndarray | None ,
150- faiss .Index | None ):
191+ def _fit_train_imputer (
192+ self , features_indices : Sequence [int ]
193+ ) -> tuple [list [int ], list [int ] | None , np .ndarray | None , faiss .Index | None ]:
151194 features_indices_to_impute = features_indices .copy ()
152195
153196 # See what features are already imputed
154- already_imputed_features_indices = [i for i in range (self .X_full .shape [1 ])
155- if not np .isnan (self .X_full [:, i ]).any ()]
197+ already_imputed_features_indices = [
198+ i for i in range (self .X_full .shape [1 ]) if not np .isnan (self .X_full [:, i ]).any ()
199+ ]
156200
157201 while True :
158202 # Train data features are those indexed by features_indices AND those already fully imputed in
@@ -184,7 +228,7 @@ def _features_indices_sorted_descending_on_nan(self) -> list[int]:
184228 self .features_nan = sorted (
185229 (i for i in range (self .X_full .shape [1 ]) if np .isnan (self .X_full [:, i ]).sum () > 0 ),
186230 key = lambda i : np .isnan (self .X_full [:, i ]).sum (),
187- reverse = True
231+ reverse = True ,
188232 )
189233
190234 return self .features_nan
@@ -198,7 +242,7 @@ def _train(self, x_train: np.ndarray) -> faiss.Index:
198242
199243 def _warn_fallback (self ):
200244 if not self .warned_fallback :
201- logger .warning (f"Fallback data (as defined by passed strategy) were used. "
202- f"This warning will only be displayed once." )
245+ logger .warning (
246+ "Fallback data (as defined by passed strategy) were used. This warning will only be displayed once."
247+ )
203248 self .warned_fallback = True
204-
0 commit comments