Skip to content

Commit 9d73de2

Browse files
authored
fix: least_squares implementation (#1550)
1 parent 5eb444e commit 9d73de2

File tree

2 files changed

+12
-14
lines changed

2 files changed

+12
-14
lines changed

exla/test/exla/nx_linalg_doctest_test.exs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ defmodule EXLA.MLIR.NxLinAlgDoctestTest do
1010
invert: 1,
1111
matrix_power: 2
1212
]
13-
@rounding_error_doctests [triangular_solve: 3, eigh: 2, cholesky: 1, least_squares: 2]
13+
@rounding_error_doctests [triangular_solve: 3, eigh: 2, cholesky: 1, least_squares: 3]
1414

1515
@excluded_doctests @function_clause_error_doctests ++
1616
@rounding_error_doctests ++

nx/lib/nx/lin_alg.ex

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2152,12 +2152,16 @@ defmodule Nx.LinAlg do
21522152
@doc """
21532153
Return the least-squares solution to a linear matrix equation Ax = b.
21542154
2155+
## Options
2156+
2157+
* `:eps` - Rounding error threshold used to assume values as 0. Defaults to `1.0e-15`
2158+
21552159
## Examples
21562160
21572161
iex> Nx.LinAlg.least_squares(Nx.tensor([[1, 2], [2, 3]]), Nx.tensor([1, 2]))
21582162
#Nx.Tensor<
21592163
f32[2]
2160-
[1.0000004768371582, -2.665601925855299e-7]
2164+
[0.9977624416351318, 0.0011188983917236328]
21612165
>
21622166
21632167
iex> Nx.LinAlg.least_squares(Nx.tensor([[0, 1], [1, 1], [2, 1], [3, 1]]), Nx.tensor([-1, 0.2, 0.9, 2.1]))
@@ -2187,7 +2191,9 @@ defmodule Nx.LinAlg do
21872191
** (ArgumentError) the number of rows of the matrix as the 1st argument and the number of columns of the vector as the 2nd argument must be the same, got 1st argument shape {2, 2} and 2nd argument shape {3}
21882192
"""
21892193
@doc from_backend: false
2190-
defn least_squares(a, b) do
2194+
defn least_squares(a, b, opts \\ []) do
2195+
opts = keyword!(opts, eps: 1.0e-15)
2196+
21912197
%T{type: a_type, shape: a_shape} = Nx.to_tensor(a)
21922198
a_size = Nx.rank(a_shape)
21932199
%T{type: b_type, shape: b_shape} = Nx.to_tensor(b)
@@ -2235,17 +2241,9 @@ defmodule Nx.LinAlg do
22352241
)
22362242
end
22372243

2238-
case a_shape do
2239-
{m, n} when m == n ->
2240-
Nx.LinAlg.solve(a, b)
2241-
2242-
{m, n} when m != n ->
2243-
Nx.LinAlg.pinv(a, eps: 1.0e-15)
2244-
|> Nx.dot(b)
2245-
2246-
_ ->
2247-
nil
2248-
end
2244+
a
2245+
|> Nx.LinAlg.pinv(eps: opts[:eps])
2246+
|> Nx.dot(b)
22492247
end
22502248

22512249
defp apply_vectorized(tensor, fun) when is_function(fun, 1) do

0 commit comments

Comments
 (0)