Skip to content

Commit 76de8c3

Browse files
authored
Logistic regression fix (#329)
* Add average when computing the loss function
* Remove optimizer and compute gradients directly
* Add L2 regularization
* Add more unit-tests
1 parent 7c69c26 commit 76de8c3

File tree

3 files changed

+193
-158
lines changed

3 files changed

+193
-158
lines changed
Lines changed: 154 additions & 124 deletions
Original file line number · Diff line number · Diff line change
@@ -1,12 +1,11 @@
11
defmodule Scholar.Linear.LogisticRegression do
22
@moduledoc """
3-
Logistic regression in both binary and multinomial variants.
3+
Multiclass logistic regression.
44
55
Time complexity is $O(N * K * I)$ where $N$ is the number of samples, $K$ is the number of features, and $I$ is the number of iterations.
66
"""
77
import Nx.Defn
88
import Scholar.Shared
9-
alias Scholar.Linear.LinearHelpers
109

1110
@derive {Nx.Container, containers: [:coefficients, :bias]}
1211
defstruct [:coefficients, :bias]
@@ -15,35 +14,28 @@ defmodule Scholar.Linear.LogisticRegression do
1514
num_classes: [
1615
required: true,
1716
type: :pos_integer,
18-
doc: "number of classes contained in the input tensors."
17+
doc: "Number of output classes."
1918
],
20-
iterations: [
19+
max_iterations: [
2120
type: :pos_integer,
2221
default: 1000,
23-
doc: """
24-
number of iterations of gradient descent performed inside logistic
25-
regression.
26-
"""
22+
doc: "Maximum number of gradient descent iterations to perform."
2723
],
28-
learning_loop_unroll: [
29-
type: :boolean,
30-
default: false,
31-
doc: ~S"""
32-
If `true`, the learning loop is unrolled.
24+
alpha: [
25+
type: {:custom, Scholar.Options, :non_negative_number, []},
26+
default: 1.0,
27+
doc: """
28+
Constant that multiplies the L2 regularization term, controlling regularization strength.
29+
If 0, no regularization is applied.
3330
"""
3431
],
35-
optimizer: [
36-
type: {:custom, Scholar.Options, :optimizer, []},
37-
default: :sgd,
32+
tol: [
33+
type: {:custom, Scholar.Options, :non_negative_number, []},
34+
default: 1.0e-4,
3835
doc: """
39-
The optimizer name or {init, update} pair of functions (see `Polaris.Optimizers` for more details).
36+
Convergence tolerance. If the infinity norm of the gradient is less than `:tol`,
37+
the algorithm is considered to have converged.
4038
"""
41-
],
42-
eps: [
43-
type: :float,
44-
default: 1.0e-8,
45-
doc:
46-
"The convergence tolerance. If the `abs(loss) < size(x) * :eps`, the algorithm is considered to have converged."
4739
]
4840
]
4941

@@ -53,9 +45,6 @@ defmodule Scholar.Linear.LogisticRegression do
5345
Fits a logistic regression model for sample inputs `x` and sample
5446
targets `y`.
5547
56-
Depending on number of classes the function chooses either binary
57-
or multinomial logistic regression.
58-
5948
## Options
6049
6150
#{NimbleOptions.docs(@opts_schema)}
@@ -68,10 +57,6 @@ defmodule Scholar.Linear.LogisticRegression do
6857
6958
* `:bias` - Bias added to the decision function.
7059
71-
* `:mode` - Indicates whether the problem is binary classification (`:num_classes` set to 2)
72-
or multinomial (`:num_classes` is bigger than 2). For binary classification set to `:binary`, otherwise
73-
set to `:multinomial`.
74-
7560
## Examples
7661
7762
iex> x = Nx.tensor([[1.0, 2.0], [3.0, 2.0], [4.0, 7.0]])
@@ -80,134 +65,177 @@ defmodule Scholar.Linear.LogisticRegression do
8065
%Scholar.Linear.LogisticRegression{
8166
coefficients: Nx.tensor(
8267
[
83-
[2.5531527996063232, -0.5531544089317322],
84-
[-0.35652396082878113, 2.3565237522125244]
68+
[0.0915902629494667, -0.09159023314714432],
69+
[-0.1507941037416458, 0.1507941335439682]
8570
]
8671
),
87-
bias: Nx.tensor(
88-
[-0.28847914934158325, 0.28847917914390564]
89-
)
72+
bias: Nx.tensor([-0.06566660106182098, 0.06566664576530457])
9073
}
9174
"""
9275
deftransform fit(x, y, opts \\ []) do
9376
if Nx.rank(x) != 2 do
9477
raise ArgumentError,
95-
"expected x to have shape {n_samples, n_features}, got tensor with shape: #{inspect(Nx.shape(x))}"
78+
"expected x to have shape {num_samples, num_features}, got tensor with shape: #{inspect(Nx.shape(x))}"
9679
end
9780

98-
{n_samples, _} = Nx.shape(x)
99-
y = LinearHelpers.validate_y_shape(y, n_samples, __MODULE__)
100-
101-
opts = NimbleOptions.validate!(opts, @opts_schema)
102-
103-
{optimizer, opts} = Keyword.pop!(opts, :optimizer)
104-
105-
{optimizer_init_fn, optimizer_update_fn} =
106-
case optimizer do
107-
atom when is_atom(atom) -> apply(Polaris.Optimizers, atom, [])
108-
{f1, f2} -> {f1, f2}
109-
end
81+
if Nx.rank(y) != 1 do
82+
raise ArgumentError,
83+
"expected y to have shape {num_samples}, got tensor with shape: #{inspect(Nx.shape(y))}"
84+
end
11085

111-
n = Nx.axis_size(x, -1)
112-
num_classes = opts[:num_classes]
86+
num_samples = Nx.axis_size(x, 0)
11387

114-
coef =
115-
Nx.broadcast(
116-
Nx.tensor(1.0, type: to_float_type(x)),
117-
{n, num_classes}
118-
)
88+
if Nx.axis_size(y, 0) != num_samples do
89+
raise ArgumentError,
90+
"expected x and y to have the same number of samples, got #{num_samples} and #{Nx.axis_size(y, 0)}"
91+
end
11992

120-
bias = Nx.broadcast(Nx.tensor(0, type: to_float_type(x)), {num_classes})
93+
opts = NimbleOptions.validate!(opts, @opts_schema)
12194

122-
coef_optimizer_state = optimizer_init_fn.(coef) |> as_type(to_float_type(x))
123-
bias_optimizer_state = optimizer_init_fn.(bias) |> as_type(to_float_type(x))
95+
type = to_float_type(x)
12496

125-
opts = Keyword.put(opts, :optimizer_update_fn, optimizer_update_fn)
97+
{alpha, opts} = Keyword.pop!(opts, :alpha)
98+
alpha = Nx.tensor(alpha, type: type)
99+
{tol, opts} = Keyword.pop!(opts, :tol)
100+
tol = Nx.tensor(tol, type: type)
126101

127-
fit_n(x, y, coef, bias, coef_optimizer_state, bias_optimizer_state, opts)
102+
fit_n(x, y, alpha, tol, opts)
128103
end
129104

130-
deftransformp as_type(container, target_type) do
131-
Nx.Defn.Composite.traverse(container, fn t ->
132-
type = Nx.type(t)
105+
defnp fit_n(x, y, alpha, tol, opts) do
106+
num_classes = opts[:num_classes]
107+
max_iterations = opts[:max_iterations]
108+
{num_samples, num_features} = Nx.shape(x)
133109

134-
if Nx.Type.float?(type) and not Nx.Type.complex?(type) do
135-
Nx.as_type(t, target_type)
136-
else
137-
t
138-
end
139-
end)
140-
end
110+
type = to_float_type(x)
141111

142-
# Logistic Regression training loop
112+
# Initialize weights and bias with zeros
113+
w =
114+
Nx.broadcast(
115+
Nx.tensor(0.0, type: type),
116+
{num_features, num_classes}
117+
)
143118

144-
defnp fit_n(x, y, coef, bias, coef_optimizer_state, bias_optimizer_state, opts) do
145-
num_samples = Nx.axis_size(x, 0)
146-
iterations = opts[:iterations]
147-
num_classes = opts[:num_classes]
148-
optimizer_update_fn = opts[:optimizer_update_fn]
119+
b = Nx.broadcast(Nx.tensor(0.0, type: type), {num_classes})
149120

121+
# One-hot encoding of target labels
150122
y_one_hot =
151123
y
152124
|> Nx.new_axis(1)
153125
|> Nx.broadcast({num_samples, num_classes})
154126
|> Nx.equal(Nx.iota({num_samples, num_classes}, axis: 1))
155127

156-
{{final_coef, final_bias}, _} =
157-
while {{coef, bias},
158-
{x, iterations, y_one_hot, coef_optimizer_state, bias_optimizer_state,
159-
has_converged = Nx.u8(0), iter = 0}},
160-
iter < iterations and not has_converged do
161-
{loss, {coef_grad, bias_grad}} = loss_and_grad(coef, bias, x, y_one_hot)
162-
163-
{coef_updates, coef_optimizer_state} =
164-
optimizer_update_fn.(coef_grad, coef_optimizer_state, coef)
165-
166-
coef = Polaris.Updates.apply_updates(coef, coef_updates)
167-
168-
{bias_updates, bias_optimizer_state} =
169-
optimizer_update_fn.(bias_grad, bias_optimizer_state, bias)
128+
# Define Armijo parameters
129+
c = Nx.tensor(1.0e-4, type: type)
130+
rho = Nx.tensor(0.5, type: type)
170131

171-
bias = Polaris.Updates.apply_updates(bias, bias_updates)
132+
eta_min =
133+
case type do
134+
{:f, 32} -> Nx.tensor(1.0e-6, type: type)
135+
{:f, 64} -> Nx.tensor(1.0e-8, type: type)
136+
_ -> Nx.tensor(1.0e-6, type: type)
137+
end
172138

173-
has_converged = Nx.sum(Nx.abs(loss)) < Nx.size(x) * opts[:eps]
139+
armijo_params = %{
140+
c: c,
141+
rho: rho,
142+
eta_min: eta_min
143+
}
174144

175-
{{coef, bias},
176-
{x, iterations, y_one_hot, coef_optimizer_state, bias_optimizer_state, has_converged,
177-
iter + 1}}
145+
{coef, bias, _} =
146+
while {w, b,
147+
{alpha, x, y_one_hot, tol, armijo_params, iter = Nx.u32(0), converged? = Nx.u8(0)}},
148+
iter < max_iterations and not converged? do
149+
logits = Nx.dot(x, w) + b
150+
probabilities = softmax(logits)
151+
residuals = probabilities - y_one_hot
152+
153+
# Compute loss
154+
loss =
155+
logits
156+
|> log_softmax()
157+
|> Nx.multiply(y_one_hot)
158+
|> Nx.sum(axes: [1])
159+
|> Nx.mean()
160+
|> Nx.negate()
161+
|> Nx.add(alpha * Nx.sum(w * w))
162+
163+
# Compute gradients
164+
grad_w = Nx.dot(x, [0], residuals, [0]) / num_samples + 2 * alpha * w
165+
grad_b = Nx.sum(residuals, axes: [0]) / num_samples
166+
167+
# Perform line search to find step size
168+
eta =
169+
armijo_line_search(w, b, alpha, x, y_one_hot, loss, grad_w, grad_b, armijo_params)
170+
171+
w = w - eta * grad_w
172+
b = b - eta * grad_b
173+
174+
converged? =
175+
Nx.reduce_max(Nx.abs(grad_w)) < tol and Nx.reduce_max(Nx.abs(grad_b)) < tol
176+
177+
{w, b, {alpha, x, y_one_hot, tol, armijo_params, iter + 1, converged?}}
178178
end
179179

180180
%__MODULE__{
181-
coefficients: final_coef,
182-
bias: final_bias
181+
coefficients: coef,
182+
bias: bias
183183
}
184184
end
185185

186-
defnp loss_and_grad(coeff, bias, xs, ys) do
187-
value_and_grad({coeff, bias}, fn {coeff, bias} ->
188-
-Nx.sum(ys * log_softmax(Nx.dot(xs, coeff) + bias), axes: [-1])
189-
end)
186+
defnp armijo_line_search(w, b, alpha, x, y, loss, grad_w, grad_b, armijo_params) do
187+
c = armijo_params[:c]
188+
rho = armijo_params[:rho]
189+
eta_min = armijo_params[:eta_min]
190+
191+
type = to_float_type(x)
192+
dir_w = -grad_w
193+
dir_b = -grad_b
194+
# Directional derivative
195+
slope = Nx.sum(dir_w * grad_w) + Nx.sum(dir_b * grad_b)
196+
197+
{eta, _} =
198+
while {eta = Nx.tensor(1.0, type: type),
199+
{w, b, alpha, x, y, loss, dir_w, dir_b, slope, c, rho, eta_min}},
200+
compute_loss(w + eta * dir_w, b + eta * dir_b, alpha, x, y) > loss + c * eta * slope and
201+
eta > eta_min do
202+
eta = eta * rho
203+
204+
{eta, {w, b, alpha, x, y, loss, dir_w, dir_b, slope, c, rho, eta_min}}
205+
end
206+
207+
eta
208+
end
209+
210+
defnp compute_loss(w, b, alpha, x, y) do
211+
x
212+
|> Nx.dot(w)
213+
|> Nx.add(b)
214+
|> log_softmax()
215+
|> Nx.multiply(y)
216+
|> Nx.sum(axes: [1])
217+
|> Nx.mean()
218+
|> Nx.negate()
219+
|> Nx.add(alpha * Nx.sum(w * w))
220+
end
221+
222+
defnp softmax(logits) do
223+
max = stop_grad(Nx.reduce_max(logits, axes: [1], keep_axes: true))
224+
normalized_exp = (logits - max) |> Nx.exp()
225+
normalized_exp / Nx.sum(normalized_exp, axes: [1], keep_axes: true)
190226
end
191227

192228
defnp log_softmax(x) do
193-
shifted = x - stop_grad(Nx.reduce_max(x, axes: [-1], keep_axes: true))
229+
shifted = x - stop_grad(Nx.reduce_max(x, axes: [1], keep_axes: true))
194230

195231
shifted
196232
|> Nx.exp()
197-
|> Nx.sum(axes: [-1], keep_axes: true)
233+
|> Nx.sum(axes: [1], keep_axes: true)
198234
|> Nx.log()
199235
|> Nx.negate()
200236
|> Nx.add(shifted)
201237
end
202238

203-
# Normalized softmax
204-
205-
defnp softmax(t) do
206-
max = stop_grad(Nx.reduce_max(t, axes: [-1], keep_axes: true))
207-
normalized_exp = (t - max) |> Nx.exp()
208-
normalized_exp / Nx.sum(normalized_exp, axes: [-1], keep_axes: true)
209-
end
210-
211239
@doc """
212240
Makes predictions with the given `model` on inputs `x`.
213241
@@ -219,14 +247,16 @@ defmodule Scholar.Linear.LogisticRegression do
219247
iex> y = Nx.tensor([1, 0, 1])
220248
iex> model = Scholar.Linear.LogisticRegression.fit(x, y, num_classes: 2)
221249
iex> Scholar.Linear.LogisticRegression.predict(model, Nx.tensor([[-3.0, 5.0]]))
222-
#Nx.Tensor<
223-
s32[1]
224-
[1]
225-
>
250+
Nx.tensor([1])
226251
"""
227252
defn predict(%__MODULE__{coefficients: coeff, bias: bias} = _model, x) do
228-
inter = Nx.dot(x, [1], coeff, [0]) + bias
229-
Nx.argmax(inter, axis: 1)
253+
if Nx.rank(x) != 2 do
254+
raise ArgumentError,
255+
"expected x to have shape {n_samples, n_features}, got tensor with shape: #{inspect(Nx.shape(x))}"
256+
end
257+
258+
logits = Nx.dot(x, coeff) + bias
259+
Nx.argmax(logits, axis: 1)
230260
end
231261

232262
@doc """
@@ -238,14 +268,14 @@ defmodule Scholar.Linear.LogisticRegression do
238268
iex> y = Nx.tensor([1, 0, 1])
239269
iex> model = Scholar.Linear.LogisticRegression.fit(x, y, num_classes: 2)
240270
iex> Scholar.Linear.LogisticRegression.predict_probability(model, Nx.tensor([[-3.0, 5.0]]))
241-
#Nx.Tensor<
242-
f32[1][2]
243-
[
244-
[6.470913388456623e-11, 1.0]
245-
]
246-
>
271+
Nx.tensor([[0.10075931251049042, 0.8992406725883484]])
247272
"""
248273
defn predict_probability(%__MODULE__{coefficients: coeff, bias: bias} = _model, x) do
249-
softmax(Nx.dot(x, [1], coeff, [0]) + bias)
274+
if Nx.rank(x) != 2 do
275+
raise ArgumentError,
276+
"expected x to have shape {n_samples, n_features}, got tensor with shape: #{inspect(Nx.shape(x))}"
277+
end
278+
279+
softmax(Nx.dot(x, coeff) + bias)
250280
end
251281
end

lib/scholar/model_selection.ex

Lines changed: 2 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -178,8 +178,8 @@ defmodule Scholar.ModelSelection do
178178
iex> y = Nx.tensor([0, 1, 2, 0, 1, 1, 0])
179179
iex> opts = [
180180
...> num_classes: [3],
181-
...> iterations: [10, 20, 50],
182-
...> optimizer: [Polaris.Optimizers.adam(learning_rate: 0.005), Polaris.Optimizers.adam(learning_rate: 0.01)],
181+
...> max_iterations: [10, 20, 50],
182+
...> alpha: [0.0, 0.1, 1.0],
183183
...> ]
184184
iex> Scholar.ModelSelection.grid_search(x, y, folding_fun, scoring_fun, opts)
185185
"""

0 commit comments

Comments
 (0)