Skip to content

Commit c11afad

Browse files
authored
Add KNNImputer (#303)
1 parent 4de37f1 commit c11afad

File tree

2 files changed

+376
-0
lines changed

2 files changed

+376
-0
lines changed

lib/scholar/impute/knn_imputter.ex

Lines changed: 247 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,247 @@
1+
defmodule Scholar.Impute.KNNImputter do
  @moduledoc """
  Imputer for completing missing values using k-Nearest Neighbors.

  Each sample's missing values are imputed using the mean value from
  `n_neighbors` nearest neighbors found in the training set. Two samples are
  close if the features that neither is missing are close.
  """
  import Nx.Defn
  import Scholar.Metrics.Distance

  # :missing_values may be an atom (:nan), so it is kept out of the
  # Nx.Container serialization; only :statistics is a tensor.
  @derive {Nx.Container, keep: [:missing_values], containers: [:statistics]}
  defstruct [:statistics, :missing_values]

  opts_schema = [
    missing_values: [
      type: {:or, [:float, :integer, {:in, [:infinity, :neg_infinity, :nan]}]},
      default: :nan,
      doc: ~S"""
      The placeholder for the missing values. All occurrences of `:missing_values` will be imputed.

      The default value expects there are no NaNs in the input tensor.
      """
    ],
    num_neighbors: [
      type: :pos_integer,
      default: 2,
      doc: "The number of nearest neighbors."
    ]
  ]

  @opts_schema NimbleOptions.new!(opts_schema)

  @doc """
  Fits a k-NN imputer: precomputes, for every missing entry of `x`, the mean of
  that feature over the `:num_neighbors` nearest rows.

  Preconditions:
  * The number of neighbors must be less than the number of valid rows - 1.
  * A valid row is a row with more than 1 non-NaN values. Otherwise it is better to use a simpler imputer.
  * When you set a value different than :nan in `missing_values` there should be no NaNs in the input tensor

  ## Options

  #{NimbleOptions.docs(@opts_schema)}

  ## Return Values

  The function returns a struct with the following parameters:

  * `:missing_values` - the same value as in the `:missing_values` option

  * `:statistics` - The imputation fill value for each missing entry. Positions
    that did not need imputation hold NaN (see the example below).

  ## Examples

      iex> x = Nx.tensor([[40.0, 2.0],[4.0, 5.0],[7.0, :nan],[:nan, 8.0],[11.0, 11.0]])
      iex> Scholar.Impute.KNNImputter.fit(x, num_neighbors: 2)
      %Scholar.Impute.KNNImputter{
        statistics: Nx.tensor(
          [
            [:nan, :nan],
            [:nan, :nan],
            [:nan, 8.0],
            [7.5, :nan],
            [:nan, :nan]
          ]
        ),
        missing_values: :nan
      }

  """
  deftransform fit(x, opts \\ []) do
    opts = NimbleOptions.validate!(opts, @opts_schema)

    input_rank = Nx.rank(x)

    if input_rank != 2 do
      raise ArgumentError, "wrong input rank. Expected: 2, got: #{inspect(input_rank)}"
    end

    missing_values = opts[:missing_values]

    # Normalize the placeholder: rewrite custom missing-value markers to NaN so
    # the numerical kernels below only ever have to detect NaN.
    x =
      if missing_values != :nan,
        do: Nx.select(Nx.equal(x, missing_values), :nan, x),
        else: x

    statistics =
      knn_impute(x, num_neighbors: opts[:num_neighbors], missing_values: missing_values)

    %__MODULE__{statistics: statistics, missing_values: missing_values}
  end

  @doc """
  Impute all missing values in `x` using fitted imputer.

  ## Return Values

  The function returns input tensor with NaN replaced with values saved in fitted imputer.

  ## Examples

      iex> x = Nx.tensor([[40.0, 2.0],[4.0, 5.0],[7.0, :nan],[:nan, 8.0],[11.0, 11.0]])
      iex> imputer = Scholar.Impute.KNNImputter.fit(x, num_neighbors: 2)
      iex> Scholar.Impute.KNNImputter.transform(imputer, x)
      Nx.tensor(
        [
          [40.0, 2.0],
          [4.0, 5.0],
          [7.0, 8.0],
          [7.5, 8.0],
          [11.0, 11.0]
        ]
      )
  """
  deftransform transform(%__MODULE__{statistics: statistics, missing_values: missing_values}, x) do
    # Missing positions take the precomputed fill value; everything else is
    # passed through untouched.
    mask = if missing_values == :nan, do: Nx.is_nan(x), else: Nx.equal(x, missing_values)
    Nx.select(mask, statistics, x)
  end

  # Builds the statistics tensor: same shape as `x`, NaN everywhere except at
  # originally-missing positions, which receive the k-NN mean.
  defnp knn_impute(x, opts \\ []) do
    mask = Nx.is_nan(x)
    {num_rows, num_cols} = Nx.shape(x)
    num_neighbors = opts[:num_neighbors]

    placeholder_value = Nx.tensor(:nan)

    values_to_impute = Nx.broadcast(placeholder_value, x)

    # Nested while loops visit every (row, col); only masked (missing) cells
    # trigger a neighbor search. All loop-carried state must be threaded
    # through the tuples explicitly.
    {_, values_to_impute} =
      while {{row = 0, mask, num_neighbors, num_rows, x}, values_to_impute},
            row < num_rows do
        {_, values_to_impute} =
          while {{col = 0, mask, num_neighbors, num_cols, row, x}, values_to_impute},
                col < num_cols do
            if mask[row][col] do
              {rows, cols} = Nx.shape(x)

              neighbor_avg =
                calculate_knn(x, row, col, rows: rows, num_neighbors: opts[:num_neighbors])

              values_to_impute =
                Nx.put_slice(values_to_impute, [row, col], Nx.reshape(neighbor_avg, {1, 1}))

              {{col + 1, mask, num_neighbors, cols, row, x}, values_to_impute}
            else
              {{col + 1, mask, num_neighbors, num_cols, row, x}, values_to_impute}
            end
          end

        {{row + 1, mask, num_neighbors, num_rows, x}, values_to_impute}
      end

    values_to_impute
  end

  # Returns the mean of column `nan_col` over the `num_neighbors` rows closest
  # (by NaN-aware Euclidean distance) to row `nan_row`.
  defnp calculate_knn(x, nan_row, nan_col, opts \\ []) do
    opts = keyword!(opts, rows: 1, num_neighbors: 2)
    rows = opts[:rows]
    num_neighbors = opts[:num_neighbors]

    row_distances = Nx.iota({rows}, type: {:f, 32})

    row_with_value_to_fill = x[nan_row]

    # calculate distance between row with nan to fill and all other rows where distance
    # to the row is under its index in the tensor
    {_, row_distances} =
      while {{i = 0, x, row_with_value_to_fill, rows, nan_row, nan_col}, row_distances},
            i < rows do
        potential_donor = x[i]

        distance =
          calculate_distance(row_with_value_to_fill, nan_col, potential_donor, nan_row)

        row_distances = Nx.indexed_put(row_distances, Nx.new_axis(i, 0), distance)
        {{i + 1, x, row_with_value_to_fill, rows, nan_row, nan_col}, row_distances}
      end

    # Smallest distances first: top_k on the negated distances yields the
    # indices of the `num_neighbors` nearest rows.
    {_, indices} = Nx.top_k(-row_distances, k: num_neighbors)

    gather_indices = Nx.stack([indices, Nx.broadcast(nan_col, indices)], axis: 1)
    values = Nx.gather(x, gather_indices)
    Nx.sum(values) / num_neighbors
  end

  # Distance from the row being imputed to one candidate donor row.
  # NOTE(review): the `^nan_row` clause compares the row *tensor* against the
  # integer row *index* — it looks like it can never match, and the self-row
  # appears to be excluded instead by the NaN->infinity select at the end of
  # nan_euclidean/3. TODO confirm and simplify.
  defnp calculate_distance(row, nan_col, potential_donor, nan_row) do
    case row do
      ^nan_row -> Nx.Constants.infinity(Nx.type(row))
      _ -> nan_euclidean(row, nan_col, potential_donor)
    end
  end

  # nan_col is the column of the value to impute
  defnp nan_euclidean(row, nan_col, potential_neighbor) do
    {coordinates} = Nx.shape(row)

    # minus nan column
    coordinates = coordinates - 1

    # puts zero in nan_col so that column contributes nothing to squared_euclidean
    new_row = Nx.indexed_put(row, Nx.new_axis(nan_col, 0), Nx.tensor(0))

    # if potential neighbor has nan in nan_col, we don't want to calculate distance and the case if potential_neighbour is the row to impute
    {potential_neighbor} =
      if Nx.is_nan(potential_neighbor[nan_col]) do
        # donor cannot supply a value for nan_col -> push it infinitely far away
        potential_neighbor =
          Nx.broadcast(Nx.Constants.infinity(Nx.type(potential_neighbor)), potential_neighbor)

        {potential_neighbor}
      else
        # puts zero in nan_col to match new_row, so that column adds 0 to the distance
        potential_neighbor =
          Nx.indexed_put(
            potential_neighbor,
            Nx.new_axis(nan_col, 0),
            Nx.tensor(0, type: Nx.type(row))
          )

        {potential_neighbor}
      end

    # how many coordinates (excluding nan_col) the donor actually has present;
    # used to weight the partial distance up to full dimensionality
    present_coordinates = Nx.sum(Nx.logical_not(Nx.is_nan(potential_neighbor))) - 1

    # if row has all nans we skip it
    {weight, potential_neighbor} =
      if present_coordinates == 0 do
        potential_neighbor =
          Nx.broadcast(Nx.Constants.infinity(Nx.type(potential_neighbor)), potential_neighbor)

        weight = 0
        {weight, potential_neighbor}
      else
        # copy our own values into the donor's NaN slots so those coordinates
        # contribute a zero difference
        potential_neighbor = Nx.select(Nx.is_nan(potential_neighbor), new_row, potential_neighbor)
        weight = coordinates / present_coordinates
        {weight, potential_neighbor}
      end

    # weighted euclidean distance over the observed coordinates
    distance = Nx.sqrt(weight * squared_euclidean(new_row, potential_neighbor))

    # return inf if potential_row is row to impute (distance came out NaN)
    Nx.select(Nx.is_nan(distance), Nx.Constants.infinity(Nx.type(distance)), distance)
  end
end
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
defmodule KNNImputterTest do
  use Scholar.Case, async: true
  alias Scholar.Impute.KNNImputter
  doctest KNNImputter

  describe "general cases" do
    # 5x4 iota tensor with a band of NaNs (where div(x, 5) == 2) and one of
    # them restored to 6.0 at position {4, 2}.
    def generate_data() do
      data = Nx.iota({5, 4})
      data = Nx.select(Nx.equal(Nx.quotient(data, 5), 2), Nx.Constants.nan(), data)
      Nx.indexed_put(data, Nx.tensor([[4, 2]]), Nx.tensor([6.0]))
    end

    test "general KNN imputer" do
      data = generate_data()
      fit_fn = Nx.Defn.jit(&KNNImputter.fit/2)
      transform_fn = Nx.Defn.jit(&KNNImputter.transform/2)

      imputer =
        %KNNImputter{statistics: statistics, missing_values: missing_values} =
        fit_fn.(data, missing_values: :nan, num_neighbors: 2)

      assert missing_values == :nan

      # Fill values are present only at originally-missing positions.
      expected_statistics =
        Nx.tensor([
          [:nan, :nan, :nan, :nan],
          [:nan, :nan, :nan, :nan],
          [:nan, :nan, 4.0, 5.0],
          [2.0, 3.0, 4.0, :nan],
          [:nan, :nan, :nan, :nan]
        ])

      assert statistics == expected_statistics

      expected_imputed =
        Nx.tensor([
          [0.0, 1.0, 2.0, 3.0],
          [4.0, 5.0, 6.0, 7.0],
          [8.0, 9.0, 4.0, 5.0],
          [2.0, 3.0, 4.0, 15.0],
          [16.0, 17.0, 6.0, 19.0]
        ])

      assert transform_fn.(imputer, data) == expected_imputed
    end

    test "general KNN imputer with different number of neighbors" do
      data = generate_data()
      fit_fn = Nx.Defn.jit(&KNNImputter.fit/2)
      transform_fn = Nx.Defn.jit(&KNNImputter.transform/2)

      imputer =
        %KNNImputter{statistics: statistics, missing_values: missing_values} =
        fit_fn.(data, missing_values: :nan, num_neighbors: 1)

      assert missing_values == :nan

      # With a single neighbor, each fill value is copied from the closest row.
      expected_statistics =
        Nx.tensor([
          [:nan, :nan, :nan, :nan],
          [:nan, :nan, :nan, :nan],
          [:nan, :nan, 2.0, 3.0],
          [0.0, 1.0, 2.0, :nan],
          [:nan, :nan, :nan, :nan]
        ])

      assert statistics == expected_statistics

      expected_imputed =
        Nx.tensor([
          [0.0, 1.0, 2.0, 3.0],
          [4.0, 5.0, 6.0, 7.0],
          [8.0, 9.0, 2.0, 3.0],
          [0.0, 1.0, 2.0, 15.0],
          [16.0, 17.0, 6.0, 19.0]
        ])

      assert transform_fn.(imputer, data) == expected_imputed
    end

    test "missing values different than :nan" do
      # Same fixture, but with 19.0 standing in as the missing-value marker.
      data = generate_data()
      data = Nx.select(Nx.is_nan(data), 19.0, data)

      fit_fn = Nx.Defn.jit(&KNNImputter.fit/2)
      transform_fn = Nx.Defn.jit(&KNNImputter.transform/2)

      imputer =
        %KNNImputter{statistics: statistics, missing_values: missing_values} =
        fit_fn.(data, missing_values: 19.0, num_neighbors: 2)

      assert missing_values == 19.0

      expected_statistics =
        Nx.tensor([
          [:nan, :nan, :nan, :nan],
          [:nan, :nan, :nan, :nan],
          [:nan, :nan, 4.0, 5.0],
          [2.0, 3.0, 4.0, :nan],
          [:nan, :nan, :nan, 5.0]
        ])

      assert statistics == expected_statistics

      expected_imputed =
        Nx.tensor([
          [0.0, 1.0, 2.0, 3.0],
          [4.0, 5.0, 6.0, 7.0],
          [8.0, 9.0, 4.0, 5.0],
          [2.0, 3.0, 4.0, 15.0],
          [16.0, 17.0, 6.0, 5.0]
        ])

      assert transform_fn.(imputer, data) == expected_imputed
    end
  end

  describe "errors" do
    test "invalid impute rank" do
      bad_input = Nx.tensor([1, 2, 2, 3])

      assert_raise ArgumentError,
                   "wrong input rank. Expected: 2, got: 1",
                   fn ->
                     KNNImputter.fit(bad_input, missing_values: 1, num_neighbors: 2)
                   end
    end

    test "invalid n_neighbors value" do
      data = generate_data()

      fit_fn = Nx.Defn.jit(&KNNImputter.fit/2)

      assert_raise NimbleOptions.ValidationError,
                   "invalid value for :num_neighbors option: expected positive integer, got: -1",
                   fn ->
                     fit_fn.(data, missing_values: 1.0, num_neighbors: -1)
                   end
    end
  end
end

0 commit comments

Comments
 (0)