Skip to content

Commit 13c4281

Browse files
Zethson and scverse-bot authored
Add support for 3D numpy (#14)
Signed-off-by: Lukas Heumos <lukas.heumos@posteo.net> Co-authored-by: scverse-bot <108668866+scverse-bot@users.noreply.github.com>
1 parent d2323f8 commit 13c4281

File tree

2 files changed

+131
-28
lines changed

2 files changed

+131
-28
lines changed

src/fknni/faiss/faiss.py

Lines changed: 70 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
from __future__ import annotations
22

3-
from typing import Literal, Any
4-
from lamin_utils import logger
3+
from collections.abc import Sequence
4+
from typing import Any, Literal
55

66
import faiss
77
import numpy as np
8-
from numpy import ndarray, dtype
8+
from lamin_utils import logger
9+
from numpy import dtype
910
from sklearn.base import BaseEstimator, TransformerMixin
1011

12+
1113
class FaissImputer(BaseEstimator, TransformerMixin):
1214
"""Imputer for completing missing values using Faiss, incorporating weighted averages based on distance."""
1315

@@ -20,6 +22,7 @@ def __init__(
2022
strategy: Literal["mean", "median", "weighted"] = "mean",
2123
index_factory: str = "Flat",
2224
min_data_ratio: float = 0.25,
25+
temporal_mode: Literal["flatten", "per_variable"] = "flatten",
2326
):
2427
"""Initializes FaissImputer with specified parameters that are used for the imputation.
2528
@@ -33,59 +36,93 @@ def __init__(
3336
index_factory: Description of the Faiss index type to build.
3437
min_data_ratio: The minimum (dimension 0) size of the FAISS index relative to the (dimension 0) size of the
3538
dataset that will be used to train FAISS. Defaults to 0.25. See also `fit_transform`.
39+
temporal_mode: How to handle 3D temporal data. 'flatten' treats all (variable, timestep) pairs as
40+
independent features (fast but allows temporal leakage).
41+
'per_variable' imputes each variable independently across time (slower but respects temporal causality).
3642
"""
3743
if n_neighbors < 1:
3844
raise ValueError("n_neighbors must be at least 1.")
3945
if strategy not in {"mean", "median", "weighted"}:
4046
raise ValueError("Unknown strategy. Choose one of 'mean', 'median', 'weighted'")
47+
if temporal_mode not in {"flatten", "per_variable"}:
48+
raise ValueError("Unknown temporal_mode. Choose one of 'flatten', 'per_variable'")
4149

4250
self.missing_values = missing_values
4351
self.n_neighbors = n_neighbors
4452
self.metric = metric
4553
self.strategy = strategy
4654
self.index_factory = index_factory
55+
self.temporal_mode = temporal_mode
4756
self.X_full = None
4857
self.features_nan = None
4958
self.min_data_ratio = min_data_ratio
5059
self.warned_fallback = False
5160
self.warned_unsufficient_neighbors = False
5261
super().__init__()
5362

54-
# @override
55-
def fit_transform(self, X: np.ndarray, y=None, **fit_params) -> ndarray[Any, dtype[Any]] | None:
63+
def fit_transform( # noqa: D417
64+
self, X: np.ndarray, y: np.ndarray | None = None, **fit_params
65+
) -> np.ndarray[Any, dtype[Any]] | None:
5666
"""Imputes missing values in the data using the fitted Faiss index. This imputation will be performed in place.
57-
This imputation will use self.min_data_ratio to check if the index is of sufficient (dimension 0) size to
58-
perform a qualitative KNN lookup. If not, it will temporarily exclude enough features to reach this threshold
59-
and try again. If an index still can't be built, it will use fallbacks values as defined by self.strategy.
67+
68+
This imputation will use `min_data_ratio` to check if the index is of sufficient (dimension 0) size to perform a qualitative KNN lookup.
69+
If not, it will temporarily exclude enough features to reach this threshold and try again.
70+
If an index still can't be built, it will use fallbacks values as defined by self.strategy.
6071
6172
Args:
62-
X: Input data with potential missing values.
73+
X: Input data with potential missing values. Can be 2D (samples × features) or 3D (samples × features × timesteps).
6374
y: Ignored, present for compatibility with sklearn's TransformerMixin.
6475
6576
Returns:
6677
Data with imputed values as a NumPy array of the original data type.
6778
"""
79+
original_shape = X.shape
80+
81+
if X.ndim == 3 and self.temporal_mode == "per_variable":
82+
n_obs, n_vars, n_t = X.shape
83+
result = np.empty_like(X, dtype=np.float64)
84+
for var_idx in range(n_vars):
85+
X_slice = X[:, var_idx, :]
86+
result[:, var_idx, :] = self._impute_2d(X_slice)
87+
return result
88+
89+
if X.ndim == 3:
90+
n_obs, n_vars, n_t = X.shape
91+
X = X.reshape(n_obs, n_vars * n_t)
92+
93+
result = self._impute_2d(X)
94+
95+
if len(original_shape) == 3:
96+
result = result.reshape(original_shape)
97+
98+
return result
99+
100+
def _impute_2d(self, X: np.ndarray) -> np.ndarray:
68101
self.X_full = np.asarray(X, dtype=np.float64) if not np.issubdtype(X.dtype, np.floating) else X
69102
if np.isnan(self.X_full).all(axis=0).any():
70103
raise ValueError("Features with only missing values cannot be handled.")
71104

72105
# Prepare fallback values, used to prefill the query vectors nan´s
73106
# or as an imputation fallback if we can't build an index
74107
global_fallbacks_ = (
75-
np.nanmean(self.X_full, axis=0) if self.strategy in ["mean", "weighted"] else np.nanmedian(self.X_full, axis=0)
108+
np.nanmean(self.X_full, axis=0)
109+
if self.strategy in ["mean", "weighted"]
110+
else np.nanmedian(self.X_full, axis=0)
76111
)
77112

78113
# We will need to impute all features having nan´s
79114
feature_indices_to_impute = [i for i in range(self.X_full.shape[1]) if np.isnan(self.X_full[:, i]).any()]
80115

81116
# Now impute iteratively
82117
while feature_indices_to_impute:
83-
feature_indices_being_imputed, training_indices, training_data, index = self._fit_train_imputer(feature_indices_to_impute)
118+
feature_indices_being_imputed, training_indices, training_data, index = self._fit_train_imputer(
119+
feature_indices_to_impute
120+
)
84121

85122
# Use fallback data if we can't build an index and iterate again
86123
if index is None:
87124
self._warn_fallback()
88-
self.X_full[:, feature_indices_being_imputed] = global_fallbacks_[feature_indices_being_imputed]
125+
self.X_full[:, feature_indices_being_imputed] = global_fallbacks_[feature_indices_being_imputed]
89126
continue
90127

91128
# Extract the features from X that was used to train FAISS, and compute the sparseness matrix
@@ -106,7 +143,9 @@ def fit_transform(self, X: np.ndarray, y=None, **fit_params) -> ndarray[Any, dty
106143
# Call FAISS and retrieve data
107144
distances, indices = index.search(sample.reshape(1, -1), self.n_neighbors)
108145
assert len(indices[0]) == self.n_neighbors
109-
valid_indices = indices[0][indices[0] >= 0] # Filter out negative indices because they are FAISS error codes
146+
valid_indices = indices[0][
147+
indices[0] >= 0
148+
] # Filter out negative indices because they are FAISS error codes
110149

111150
# FAISS couldn't find any neighbor, use fallback values, and go to next row
112151
if len(valid_indices) == 0:
@@ -117,8 +156,9 @@ def fit_transform(self, X: np.ndarray, y=None, **fit_params) -> ndarray[Any, dty
117156
# FAISS couldn't find the amount of requested neighbors, warn user and proceed
118157
if len(valid_indices) < self.n_neighbors:
119158
if not self.warned_unsufficient_neighbors:
120-
logger.warning(f"FAISS couldn't find all the requested neighbors. "
121-
f"This warning will be displayed only once.")
159+
logger.warning(
160+
"FAISS couldn't find all the requested neighbors. This warning will be displayed only once."
161+
)
122162
self.warned_unsufficient_neighbors = True
123163

124164
# Apply strategy on neighbors data
@@ -139,20 +179,24 @@ def fit_transform(self, X: np.ndarray, y=None, **fit_params) -> ndarray[Any, dty
139179
self.X_full[:, feature_indices_being_imputed] = x_imputed[:, np.arange(len(feature_indices_being_imputed))]
140180

141181
# Remove the imputed features from the to-do list
142-
feature_indices_to_impute = [feature_indice for feature_indice in feature_indices_to_impute if feature_indice not in feature_indices_being_imputed]
182+
feature_indices_to_impute = [
183+
feature_indice
184+
for feature_indice in feature_indices_to_impute
185+
if feature_indice not in feature_indices_being_imputed
186+
]
143187

144188
assert not np.isnan(self.X_full).any()
145189
return self.X_full
146190

147-
def _fit_train_imputer(self, features_indices: list[int]) -> (list[int],
148-
list[int] | None,
149-
np.ndarray | None,
150-
faiss.Index | None):
191+
def _fit_train_imputer(
192+
self, features_indices: Sequence[int]
193+
) -> tuple[list[int], list[int] | None, np.ndarray | None, faiss.Index | None]:
151194
features_indices_to_impute = features_indices.copy()
152195

153196
# See what features are already imputed
154-
already_imputed_features_indices = [i for i in range(self.X_full.shape[1])
155-
if not np.isnan(self.X_full[:, i]).any()]
197+
already_imputed_features_indices = [
198+
i for i in range(self.X_full.shape[1]) if not np.isnan(self.X_full[:, i]).any()
199+
]
156200

157201
while True:
158202
# Train data features are those indexed by features_indices AND those already fully imputed in
@@ -184,7 +228,7 @@ def _features_indices_sorted_descending_on_nan(self) -> list[int]:
184228
self.features_nan = sorted(
185229
(i for i in range(self.X_full.shape[1]) if np.isnan(self.X_full[:, i]).sum() > 0),
186230
key=lambda i: np.isnan(self.X_full[:, i]).sum(),
187-
reverse=True
231+
reverse=True,
188232
)
189233

190234
return self.features_nan
@@ -198,7 +242,7 @@ def _train(self, x_train: np.ndarray) -> faiss.Index:
198242

199243
def _warn_fallback(self):
200244
if not self.warned_fallback:
201-
logger.warning(f"Fallback data (as defined by passed strategy) were used. "
202-
f"This warning will only be displayed once.")
245+
logger.warning(
246+
"Fallback data (as defined by passed strategy) were used. This warning will only be displayed once."
247+
)
203248
self.warned_fallback = True
204-

tests/test_faiss_imputation.py

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
from typing import Any
21
import numpy as np
32
import pandas as pd
43
import pytest
54
from sklearn.datasets import make_regression
5+
66
from fknni.faiss.faiss import FaissImputer
77

88

@@ -55,7 +55,6 @@ def _base_check_imputation(
5555
if not _are_ndarrays_equal(data_original[imputed_non_nan_mask], data_imputed[imputed_non_nan_mask]):
5656
raise AssertionError("Non-NaN values in imputed columns were modified.")
5757

58-
# If reaching here: all checks passed
5958
return
6059

6160

@@ -165,6 +164,66 @@ def test_no_full_rows():
165164
_base_check_imputation(arr_original, arr)
166165

167166

167+
def test_3d_flatten_imputation(rng):
168+
"""Tests if 3D imputation with flatten mode successfully fills all NaN values"""
169+
data_3d = rng.uniform(0, 100, size=(10, 5, 3))
170+
data_missing = data_3d.copy()
171+
indices = [
172+
(i, j, k) for i in range(data_3d.shape[0]) for j in range(data_3d.shape[1]) for k in range(data_3d.shape[2])
173+
]
174+
rng.shuffle(indices)
175+
for i, j, k in indices[:20]:
176+
data_missing[i, j, k] = np.nan
177+
178+
data_original = data_missing.copy()
179+
FaissImputer(n_neighbors=5, temporal_mode="flatten").fit_transform(data_missing)
180+
_base_check_imputation(data_original, data_missing)
181+
assert data_missing.shape == (10, 5, 3)
182+
183+
184+
def test_3d_per_variable_imputation(rng):
185+
"""Tests if 3D imputation with per_variable mode successfully fills all NaN values"""
186+
data_3d = rng.uniform(0, 100, size=(10, 5, 3))
187+
data_missing = data_3d.copy()
188+
indices = [
189+
(i, j, k) for i in range(data_3d.shape[0]) for j in range(data_3d.shape[1]) for k in range(data_3d.shape[2])
190+
]
191+
rng.shuffle(indices)
192+
for i, j, k in indices[:20]:
193+
data_missing[i, j, k] = np.nan
194+
195+
data_original = data_missing.copy()
196+
FaissImputer(n_neighbors=5, temporal_mode="per_variable").fit_transform(data_missing)
197+
_base_check_imputation(data_original, data_missing)
198+
assert data_missing.shape == (10, 5, 3)
199+
200+
201+
def test_3d_modes_produce_different_results(rng):
202+
"""Tests if flatten and per_variable modes produce different results"""
203+
data_3d = rng.uniform(0, 100, size=(10, 5, 3))
204+
data_missing = data_3d.copy()
205+
indices = [
206+
(i, j, k) for i in range(data_3d.shape[0]) for j in range(data_3d.shape[1]) for k in range(data_3d.shape[2])
207+
]
208+
rng.shuffle(indices)
209+
for i, j, k in indices[:20]:
210+
data_missing[i, j, k] = np.nan
211+
212+
data_flatten = data_missing.copy()
213+
data_per_var = data_missing.copy()
214+
215+
FaissImputer(n_neighbors=5, temporal_mode="flatten").fit_transform(data_flatten)
216+
FaissImputer(n_neighbors=5, temporal_mode="per_variable").fit_transform(data_per_var)
217+
218+
assert not np.array_equal(data_flatten, data_per_var)
219+
220+
221+
def test_invalid_temporal_mode():
222+
"""Tests if imputer raises error for invalid temporal_mode"""
223+
with pytest.raises(ValueError):
224+
FaissImputer(temporal_mode="invalid")
225+
226+
168227
def _are_ndarrays_equal(arr1: np.ndarray, arr2: np.ndarray) -> np.bool_:
169228
"""Check if two arrays are equal member-wise.
170229

0 commit comments

Comments
 (0)