Skip to content

Commit 8750e79

Browse files
Zethsonscverse-bot
andauthored
Add faiss-gpu (#22)
Signed-off-by: Lukas Heumos <lukas.heumos@posteo.net> Co-authored-by: scverse-bot <108668866+scverse-bot@users.noreply.github.com>
1 parent 7e9887a commit 8750e79

File tree

10 files changed

+100
-71
lines changed

10 files changed

+100
-71
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ jobs:
2626
outputs:
2727
envs: ${{ steps.get-envs.outputs.envs }}
2828
steps:
29-
- uses: actions/checkout@v4
29+
- uses: actions/checkout@v5
3030
with:
3131
filter: blob:none
3232
fetch-depth: 0

.github/workflows/test-gpu.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,15 +42,15 @@ jobs:
4242
- name: Install Python
4343
uses: actions/setup-python@v5
4444
with:
45-
python-version: "3.13"
45+
python-version: "3.12"
4646

4747
- name: Install uv
4848
uses: astral-sh/setup-uv@v7
4949
with:
5050
cache-dependency-glob: pyproject.toml
5151

5252
- name: Install fknni
53-
run: uv pip install --system -e ".[test]"
53+
run: uv pip install --system -e ".[test,faissgpu]"
5454
- name: Pip list
5555
run: pip list
5656

pyproject.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ classifiers = [
2121
"Programming Language :: Python :: 3.13",
2222
]
2323
dependencies = [
24-
"faiss-cpu",
2524
"lamin-utils",
2625
"pandas",
2726
"scikit-learn",
@@ -45,6 +44,8 @@ optional-dependencies.doc = [
4544
"sphinxcontrib-bibtex>=1",
4645
"sphinxext-opengraph",
4746
]
47+
optional-dependencies.faisscpu = [ "faiss-cpu" ]
48+
optional-dependencies.faissgpu = [ "faiss-gpu-cu12" ]
4849
optional-dependencies.rapids12 = [
4950
"cudf-cu12>=25.10",
5051
"cugraph-cu12>=25.10",
@@ -88,7 +89,7 @@ deps = [ "pre" ]
8889
python = [ "3.13" ]
8990

9091
[tool.hatch.envs.hatch-test]
91-
features = [ "dev", "test" ]
92+
features = [ "dev", "test", "faisscpu" ]
9293

9394
[tool.hatch.envs.hatch-test.overrides]
9495
# If the matrix variable `deps` is set to "pre",
@@ -142,6 +143,7 @@ testpaths = [ "tests" ]
142143
xfail_strict = true
143144
addopts = [
144145
"--import-mode=importlib", # allow using test files with same name
146+
"-m not gpu",
145147
]
146148
markers = [
147149
"gpu: mark test to run on GPU",

src/fknni/faiss/faiss.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,13 @@
99
from numpy import dtype
1010
from sklearn.base import BaseEstimator, TransformerMixin
1111

12+
try:
13+
import faiss
14+
15+
HAS_FAISS_GPU = hasattr(faiss, "StandardGpuResources")
16+
except ImportError:
17+
raise ImportError("faiss-cpu or faiss-gpu required") from None
18+
1219

1320
class FaissImputer(BaseEstimator, TransformerMixin):
1421
"""Imputer for completing missing values using Faiss, incorporating weighted averages based on distance."""
@@ -23,6 +30,7 @@ def __init__(
2330
index_factory: str = "Flat",
2431
min_data_ratio: float = 0.25,
2532
temporal_mode: Literal["flatten", "per_variable"] = "flatten",
33+
use_gpu: bool = False,
2634
):
2735
"""Initializes FaissImputer with specified parameters that are used for the imputation.
2836
@@ -39,6 +47,7 @@ def __init__(
3947
temporal_mode: How to handle 3D temporal data. 'flatten' treats all (variable, timestep) pairs as
4048
independent features (fast but allows temporal leakage).
4149
'per_variable' imputes each variable independently across time (slower but respects temporal causality).
50+
use_gpu: Whether to train using GPU.
4251
"""
4352
if n_neighbors < 1:
4453
raise ValueError("n_neighbors must be at least 1.")
@@ -47,6 +56,10 @@ def __init__(
4756
if temporal_mode not in {"flatten", "per_variable"}:
4857
raise ValueError("Unknown temporal_mode. Choose one of 'flatten', 'per_variable'")
4958

59+
self.use_gpu = use_gpu
60+
if use_gpu and not HAS_FAISS_GPU:
61+
raise ValueError("use_gpu=True requires faiss-gpu package, install with: pip install faiss-gpu") from None
62+
5063
self.missing_values = missing_values
5164
self.n_neighbors = n_neighbors
5265
self.metric = metric
@@ -236,6 +249,11 @@ def _features_indices_sorted_descending_on_nan(self) -> list[int]:
236249
def _train(self, x_train: np.ndarray) -> faiss.Index:
237250
index = faiss.index_factory(x_train.shape[1], self.index_factory)
238251
index.metric_type = faiss.METRIC_L2 if self.metric == "l2" else faiss.METRIC_INNER_PRODUCT
252+
253+
if self.use_gpu:
254+
res = faiss.StandardGpuResources()
255+
index = faiss.index_cpu_to_gpu(res, 0, index)
256+
239257
index.train(x_train)
240258
index.add(x_train)
241259
return index

tests/__init__.py

Whitespace-only changes.

tests/compare_predictions.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
import numpy as np
2+
3+
4+
def _are_ndarrays_equal(arr1: np.ndarray, arr2: np.ndarray) -> np.bool_:
5+
"""Check if two arrays are equal member-wise.
6+
7+
Note: Two NaN are considered equal.
8+
9+
Args:
10+
arr1: First array to compare
11+
arr2: Second array to compare
12+
13+
Returns:
14+
True if the two arrays are equal member-wise
15+
"""
16+
return np.all(np.equal(arr1, arr2, dtype=object) | ((arr1 != arr1) & (arr2 != arr2)))
17+
18+
19+
def _base_check_imputation(
20+
data_original: np.ndarray,
21+
data_imputed: np.ndarray,
22+
):
23+
"""Provides the following base checks:
24+
- Imputation doesn't leave any NaN behind
25+
- Imputation doesn't modify any data that wasn't NaN
26+
27+
Args:
28+
data_before_imputation: Dataset before imputation
29+
data_after_imputation: Dataset after imputation
30+
31+
Raises:
32+
AssertionError: If any of the checks fail.
33+
"""
34+
if data_original.shape != data_imputed.shape:
35+
raise AssertionError("The shapes of the two datasets do not match")
36+
37+
# Ensure no NaN remains in the imputed dataset
38+
if np.isnan(data_imputed).any():
39+
raise AssertionError("NaN found in imputed columns of layer_after.")
40+
41+
# Ensure imputation does not alter non-NaN values in the imputed columns
42+
imputed_non_nan_mask = ~np.isnan(data_original)
43+
if not _are_ndarrays_equal(data_original[imputed_non_nan_mask], data_imputed[imputed_non_nan_mask]):
44+
raise AssertionError("Non-NaN values in imputed columns were modified.")
45+
46+
return

tests/conftest.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import numpy as np
2+
import pandas as pd
3+
import pytest
4+
5+
6+
@pytest.fixture
7+
def rng():
8+
return np.random.default_rng(0)
9+
10+
11+
@pytest.fixture
12+
def simple_test_df(rng):
13+
data = pd.DataFrame(rng.integers(0, 100, size=(10, 5)), columns=list("ABCDE"))
14+
data_missing = data.copy()
15+
indices = [(i, j) for i in range(data.shape[0]) for j in range(data.shape[1])]
16+
rng.shuffle(indices)
17+
for i, j in indices[:5]:
18+
data_missing.iat[i, j] = np.nan
19+
return data.to_numpy(), data_missing.to_numpy()

tests/cpu/conftest.py

Lines changed: 0 additions & 7 deletions
This file was deleted.

tests/cpu/test_faiss_imputation.py

Lines changed: 1 addition & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,11 @@
11
import numpy as np
2-
import pandas as pd
32
import pytest
43
from sklearn.datasets import make_regression
4+
from tests.compare_predictions import _base_check_imputation
55

66
from fknni.faiss.faiss import FaissImputer
77

88

9-
@pytest.fixture
10-
def simple_test_df(rng):
11-
data = pd.DataFrame(rng.integers(0, 100, size=(10, 5)), columns=list("ABCDE"))
12-
data_missing = data.copy()
13-
indices = [(i, j) for i in range(data.shape[0]) for j in range(data.shape[1])]
14-
rng.shuffle(indices)
15-
for i, j in indices[:5]:
16-
data_missing.iat[i, j] = np.nan
17-
return data.to_numpy(), data_missing.to_numpy()
18-
19-
209
@pytest.fixture
2110
def regression_dataset(rng):
2211
X, y = make_regression(n_samples=100, n_features=20, random_state=42)
@@ -28,36 +17,6 @@ def regression_dataset(rng):
2817
return X, X_missing, y
2918

3019

31-
def _base_check_imputation(
32-
data_original: np.ndarray,
33-
data_imputed: np.ndarray,
34-
):
35-
"""Provides the following base checks:
36-
- Imputation doesn't leave any NaN behind
37-
- Imputation doesn't modify any data that wasn't NaN
38-
39-
Args:
40-
data_before_imputation: Dataset before imputation
41-
data_after_imputation: Dataset after imputation
42-
43-
Raises:
44-
AssertionError: If any of the checks fail.
45-
"""
46-
if data_original.shape != data_imputed.shape:
47-
raise AssertionError("The shapes of the two datasets do not match")
48-
49-
# Ensure no NaN remains in the imputed dataset
50-
if np.isnan(data_imputed).any():
51-
raise AssertionError("NaN found in imputed columns of layer_after.")
52-
53-
# Ensure imputation does not alter non-NaN values in the imputed columns
54-
imputed_non_nan_mask = ~np.isnan(data_original)
55-
if not _are_ndarrays_equal(data_original[imputed_non_nan_mask], data_imputed[imputed_non_nan_mask]):
56-
raise AssertionError("Non-NaN values in imputed columns were modified.")
57-
58-
return
59-
60-
6120
def test_median_imputation(simple_test_df):
6221
"""Tests if median imputation successfully fills all NaN values"""
6322
data, data_missing = simple_test_df
@@ -222,18 +181,3 @@ def test_invalid_temporal_mode():
222181
"""Tests if imputer raises error for invalid temporal_mode"""
223182
with pytest.raises(ValueError):
224183
FaissImputer(temporal_mode="invalid")
225-
226-
227-
def _are_ndarrays_equal(arr1: np.ndarray, arr2: np.ndarray) -> np.bool_:
228-
"""Check if two arrays are equal member-wise.
229-
230-
Note: Two NaN are considered equal.
231-
232-
Args:
233-
arr1: First array to compare
234-
arr2: Second array to compare
235-
236-
Returns:
237-
True if the two arrays are equal member-wise
238-
"""
239-
return np.all(np.equal(arr1, arr2, dtype=object) | ((arr1 != arr1) & (arr2 != arr2)))

tests/gpu/test_gpu.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
import pytest
2+
from tests.compare_predictions import _base_check_imputation
3+
4+
from fknni.faiss.faiss import FaissImputer
25

36

47
@pytest.mark.gpu
5-
def test_gpu():
6-
assert 1 + 1 == 2
8+
def test_median_imputation(simple_test_df):
9+
"""Tests if median imputation successfully fills all NaN values"""
10+
data, data_missing = simple_test_df
11+
data_original = data_missing.copy()
12+
FaissImputer(n_neighbors=5, strategy="median", use_gpu=True).fit_transform(data_missing)
13+
_base_check_imputation(data_original, data_missing)

0 commit comments

Comments
 (0)