Skip to content

Commit 1013a52

Browse files
authored
fix: enforce always-positive ranges (#1529)
1 parent 6c7e672 commit 1013a52

File tree

1 file changed

+18
-11
lines changed

1 file changed

+18
-11
lines changed

nx/lib/nx/binary_backend/matrix.ex

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -116,17 +116,24 @@ defmodule Nx.BinaryBackend.Matrix do
116116

117117
defp do_ts([], [], _idx, acc), do: acc
118118

119-
defp qr_decomposition(matrix, m, n, eps) when m >= n do
119+
defp qr_decomposition(matrix, n, _eps) when n in 0..1 do
120+
{[[1.0]], matrix}
121+
end
122+
123+
defp qr_decomposition(matrix, n, eps) when n >= 2 do
120124
# QR decomposition is performed by using Householder transform
121-
max_i = if m == n, do: n - 2, else: n - 1
125+
# this function originally supported generic QR, but
126+
# it is now only used by eigh. Because of this,
127+
# we simplified the function signature to only
128+
# support square matrices.
122129

123130
{q_matrix, r_matrix} =
124-
for i <- 0..max_i, reduce: {nil, matrix} do
131+
for i <- 0..(n - 2)//1, reduce: {nil, matrix} do
125132
{q, r} ->
126133
h =
127134
r
128-
|> slice_matrix([i, i], [m - i, 1])
129-
|> householder_reflector(m, eps)
135+
|> slice_matrix([i, i], [n - i, 1])
136+
|> householder_reflector(n, eps)
130137

131138
# If we haven't allocated Q yet, let Q = H1
132139
# TODO: Resolve inconsistent with the Householder reflector.
@@ -145,10 +152,6 @@ defmodule Nx.BinaryBackend.Matrix do
145152
{approximate_zeros(q_matrix, eps), approximate_zeros(r_matrix, eps)}
146153
end
147154

148-
defp qr_decomposition(_, _, _, _) do
149-
raise ArgumentError, "tensor must have at least as many rows as columns"
150-
end
151-
152155
defp raise_not_hermitian do
153156
raise ArgumentError,
154157
"matrix must be hermitian, a matrix is hermitian iff X = adjoint(X)"
@@ -178,7 +181,7 @@ defmodule Nx.BinaryBackend.Matrix do
178181
{eigenvals_diag, eigenvecs} =
179182
Enum.reduce_while(1..max_iter//1, {h, q_h}, fn _, {a_old, q_old} ->
180183
# QR decomposition
181-
{q_now, r_now} = qr_decomposition(a_old, n, n, eps)
184+
{q_now, r_now} = qr_decomposition(a_old, n, eps)
182185

183186
# Update matrix A, Q
184187
a_new = dot_matrix_real(r_now, q_now)
@@ -203,10 +206,14 @@ defmodule Nx.BinaryBackend.Matrix do
203206
eigenvecs |> approximate_zeros(eps) |> matrix_to_binary(output_type)}
204207
end
205208

209+
defp hessenberg_decomposition(matrix, n, _eps) when n in 0..1 do
210+
{matrix, [[1.0]]}
211+
end
212+
206213
defp hessenberg_decomposition(matrix, n, eps) do
207214
# Hessenberg decomposition is performed by using Householder transform
208215
{hess_matrix, q_matrix} =
209-
for i <- 0..(n - 2), reduce: {matrix, nil} do
216+
for i <- 0..(n - 2)//1, reduce: {matrix, nil} do
210217
{hess, q} ->
211218
h =
212219
hess

0 commit comments

Comments
 (0)