Skip to content

Commit bc50857

Browse files
authored
Add RobustScaler (#314)
1 parent c11afad commit bc50857

File tree

3 files changed

+278
-0
lines changed

3 files changed

+278
-0
lines changed

lib/scholar/options.ex

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,4 +108,17 @@ defmodule Scholar.Options do
108108
{:error, "expected 'beta' to be in the range [0, inf]"}
109109
end
110110
end
111+
112+
# NimbleOptions custom validator for the `:quantile_range` option.
#
# Accepts a two-element tuple of numbers `{q_min, q_max}` with
# `0.0 < q_min < q_max < 100.0` and returns it wrapped in `{:ok, _}`.
# Any other value yields `{:error, message}`, where the message embeds
# the inspected offending value.
def quantile_range({q_min, q_max} = range)
    when is_number(q_min) and is_number(q_max) and q_min > 0.0 and q_max > q_min and
           q_max < 100.0 do
  {:ok, range}
end

def quantile_range(value) do
  {:error,
   "expected :quantile_range to be a tuple {q_min, q_max} such that 0.0 < q_min < q_max < 100.0, got: #{inspect(value)}"}
end
111124
end
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
defmodule Scholar.Preprocessing.RobustScaler do
  @moduledoc ~S"""
  Scale features using statistics that are robust to outliers.

  This Scaler removes the median and scales the data according to
  the quantile range (defaults to IQR: Interquartile Range).
  The IQR is the range between the 1st quartile (25th quantile)
  and the 3rd quartile (75th quantile).
  """

  import Nx.Defn

  @derive {Nx.Container, containers: [:medians, :iqr]}
  defstruct [:medians, :iqr]

  opts_schema = [
    quantile_range: [
      type: {:custom, Scholar.Options, :quantile_range, []},
      default: {25.0, 75.0},
      doc: """
      Quantile range as a tuple {q_min, q_max} defining the range of quantiles
      to include. Must satisfy 0.0 < q_min < q_max < 100.0.
      """
    ]
  ]

  @opts_schema NimbleOptions.new!(opts_schema)

  @doc """
  Compute the median and quantiles to be used for scaling.

  The input must be a rank-2 tensor of shape `{num_samples, num_features}`;
  an `ArgumentError` is raised otherwise.

  ## Options

  #{NimbleOptions.docs(@opts_schema)}

  ## Return values

  Returns a struct with the following parameters:

    * `:iqr` - the per-feature range between the `q_max` and `q_min` percentiles
      (the interquartile range when the default `:quantile_range` is used).

    * `:medians` - the calculated medians of each feature across samples.

  ## Examples

      iex> Scholar.Preprocessing.RobustScaler.fit(Nx.tensor([[1, -1, 2], [2, 0, 0], [0, 1, -1]]))
      %Scholar.Preprocessing.RobustScaler{
        medians: Nx.tensor([1, 0, 0]),
        iqr: Nx.tensor([1.0, 1.0, 1.5])
      }
  """
  deftransform fit(tensor, opts \\ []) do
    # Options are validated eagerly, before any numerical work is traced.
    fit_n(tensor, NimbleOptions.validate!(opts, @opts_schema))
  end

  defnp fit_n(tensor, opts) do
    check_for_rank(tensor)

    {q_min, q_max} = opts[:quantile_range]

    medians = Nx.median(tensor, axis: 0)

    # Sort once along the sample axis and reuse it for both percentiles.
    sorted_tensor = Nx.sort(tensor, axis: 0)

    lower = percentile(sorted_tensor, q_min)
    upper = percentile(sorted_tensor, q_max)

    %__MODULE__{medians: medians, iqr: upper - lower}
  end

  @doc """
  Performs centering and scaling of the tensor using a fitted scaler.

  Each feature is centered on its median and divided by its quantile range.
  Features whose range is zero are only centered (they are divided by `1.0`).
  The input must be a rank-2 tensor; an `ArgumentError` is raised otherwise.

  ## Examples

      iex> t = Nx.tensor([[1, -1, 2], [2, 0, 0], [0, 1, -1]])
      iex> scaler = Scholar.Preprocessing.RobustScaler.fit(t)
      %Scholar.Preprocessing.RobustScaler{
        medians: Nx.tensor([1, 0, 0]),
        iqr: Nx.tensor([1.0, 1.0, 1.5])
      }
      iex> Scholar.Preprocessing.RobustScaler.transform(scaler, t)
      #Nx.Tensor<
        f32[3][3]
        [
          [0.0, -1.0, 1.3333333730697632],
          [1.0, 0.0, 0.0],
          [-1.0, 1.0, -0.6666666865348816]
        ]
      >
  """
  defn transform(%__MODULE__{medians: medians, iqr: iqr}, tensor) do
    check_for_rank(tensor)
    scale(tensor, medians, iqr)
  end

  @doc """
  Computes the scaling parameters and applies them to transform the tensor.

  Accepts the same options as `fit/2`.

  ## Examples

      iex> t = Nx.tensor([[1, -1, 2], [2, 0, 0], [0, 1, -1]])
      iex> Scholar.Preprocessing.RobustScaler.fit_transform(t)
      #Nx.Tensor<
        f32[3][3]
        [
          [0.0, -1.0, 1.3333333730697632],
          [1.0, 0.0, 0.0],
          [-1.0, 1.0, -0.6666666865348816]
        ]
      >
  """
  defn fit_transform(tensor, opts \\ []) do
    tensor
    |> fit(opts)
    |> transform(tensor)
  end

  # Centers `tensor` on `medians` and divides by `iqr`, substituting 1.0 for
  # zero ranges so constant features do not divide by zero.
  defnp scale(tensor, medians, iqr) do
    # Promote to a floating type before centering. The medians of integer
    # input can keep the integer type (see the doctests), so for unsigned
    # input `tensor - medians` would wrap around instead of going negative.
    centered = Nx.as_type(tensor, float_type(tensor)) - medians
    centered / Nx.select(iqr == 0, 1.0, iqr)
  end

  # Floating-point type corresponding to the tensor's current type.
  deftransformp float_type(tensor) do
    tensor |> Nx.type() |> Nx.Type.to_floating()
  end

  # Computes the `p`-th percentile of every column of a tensor that is already
  # sorted along axis 0, linearly interpolating between the two nearest rows.
  defnp percentile(sorted_tensor, p) do
    num_rows = Nx.axis_size(sorted_tensor, 0)
    # Fractional row position of the requested percentile.
    idx = p / 100 * (num_rows - 1)

    lower_idx = Nx.floor(idx) |> Nx.as_type(:s64)
    upper_idx = Nx.ceil(idx) |> Nx.as_type(:s64)

    lower_values = Nx.take(sorted_tensor, lower_idx, axis: 0)
    upper_values = Nx.take(sorted_tensor, upper_idx, axis: 0)

    # The fractional part of `idx` is the weight given to the upper row.
    weight_upper = idx - Nx.floor(idx)
    weight_lower = 1.0 - weight_upper
    lower_values * weight_lower + upper_values * weight_upper
  end

  # Raises `ArgumentError` unless `tensor` is rank 2.
  defnp check_for_rank(tensor) do
    if Nx.rank(tensor) != 2 do
      raise ArgumentError,
            """
            expected tensor to have shape {num_samples, num_features}, \
            got tensor with shape: #{inspect(Nx.shape(tensor))}\
            """
    end
  end
end
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
defmodule Scholar.Preprocessing.RobustScalerTest do
  # Tests for Scholar.Preprocessing.RobustScaler: scaling results for a range
  # of inputs, plus the errors raised for bad tensor ranks and bad options.
  use Scholar.Case, async: true
  alias Scholar.Preprocessing.RobustScaler
  doctest RobustScaler

  describe "fit_transform" do
    test "applies scaling to data" do
      data = Nx.tensor([[1, -1, 2], [2, 0, 0], [0, 1, -1]])

      expected =
        Nx.tensor([
          [0.0, -1.0, 1.3333333333333333],
          [1.0, 0.0, 0.0],
          [-1.0, 1.0, -0.6666666666666666]
        ])

      assert_all_close(RobustScaler.fit_transform(data), expected)
    end

    test "applies scaling to data with custom quantile range" do
      data = Nx.tensor([[1, -1, 2], [2, 0, 0], [0, 1, -1]])

      expected =
        Nx.tensor([
          [0.0, -0.7142857142857142, 1.0],
          [0.7142857142857142, 0.0, 0.0],
          [-0.7142857142857142, 0.7142857142857142, -0.5]
        ])

      assert_all_close(
        RobustScaler.fit_transform(data, quantile_range: {10, 80}),
        expected
      )
    end

    test "handles constant data (all values the same)" do
      data = Nx.tensor([[5, 5, 5], [5, 5, 5], [5, 5, 5]])
      expected = Nx.tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])

      assert_all_close(RobustScaler.fit_transform(data), expected)
    end

    test "handles already scaled data" do
      data = Nx.tensor([[0, -1, 1], [1, 0, 0], [-1, 1, -1]])
      expected = data

      assert_all_close(RobustScaler.fit_transform(data), expected)
    end

    test "handles single-row tensor" do
      data = Nx.tensor([[1, 2, 3]])
      expected = Nx.tensor([[0.0, 0.0, 0.0]])

      assert_all_close(RobustScaler.fit_transform(data), expected)
    end

    test "handles single-column tensor" do
      data = Nx.tensor([[1], [2], [3]])
      expected = Nx.tensor([[-1.0], [0.0], [1.0]])

      assert_all_close(RobustScaler.fit_transform(data), expected)
    end

    test "handles data with negative values only" do
      data = Nx.tensor([[-5, -10, -15], [-15, -5, -20], [-10, -15, -5]])

      expected =
        Nx.tensor([
          [1.0, 0.0, 0.0],
          [-1.0, 1.0, -0.6666666666666666],
          [0.0, -1.0, 1.3333333333333333]
        ])

      assert_all_close(RobustScaler.fit_transform(data), expected)
    end

    test "handles data with extreme outliers" do
      data = Nx.tensor([[1, 2, 3], [1000, 2000, 3000], [-1000, -2000, -3000]])

      expected =
        Nx.tensor([[0.0, 0.0, 0.0], [0.999, 0.999, 0.999], [-1.001, -1.001, -1.001]])

      assert_all_close(
        RobustScaler.fit_transform(data),
        expected
      )
    end
  end

  describe "errors" do
    test "wrong input rank for fit" do
      assert_raise ArgumentError,
                   "expected tensor to have shape {num_samples, num_features}, got tensor with shape: {1, 1, 1}",
                   fn ->
                     RobustScaler.fit(Nx.tensor([[[1]]]))
                   end
    end

    test "wrong input rank for transform" do
      assert_raise ArgumentError,
                   "expected tensor to have shape {num_samples, num_features}, got tensor with shape: {1, 1, 1}",
                   fn ->
                     RobustScaler.fit(Nx.tensor([[1]]))
                     |> RobustScaler.transform(Nx.tensor([[[1]]]))
                   end
    end

    test "wrong quantile range" do
      # A valid rank-2 tensor is used on purpose, so this test exercises only
      # option validation and does not depend on it running before the rank check.
      assert_raise NimbleOptions.ValidationError,
                   "invalid value for :quantile_range option: expected :quantile_range to be a tuple {q_min, q_max} such that 0.0 < q_min < q_max < 100.0, got: {10, 800}",
                   fn ->
                     RobustScaler.fit(Nx.tensor([[1]]), quantile_range: {10, 800})
                   end
    end

    test "quantile range with q_min greater than q_max" do
      assert_raise NimbleOptions.ValidationError, fn ->
        RobustScaler.fit(Nx.tensor([[1]]), quantile_range: {80, 10})
      end
    end
  end
end

0 commit comments

Comments
 (0)