@@ -116,17 +116,24 @@ defmodule Nx.BinaryBackend.Matrix do
116
116
117
117
defp do_ts ( [ ] , [ ] , _idx , acc ) , do: acc
118
118
119
- defp qr_decomposition ( matrix , m , n , eps ) when m >= n do
119
+ defp qr_decomposition ( matrix , n , _eps ) when n in 0 .. 1 do
120
+ { [ [ 1.0 ] ] , matrix }
121
+ end
122
+
123
+ defp qr_decomposition ( matrix , n , eps ) when n >= 2 do
120
124
# QR decomposition is performed by using Householder transform
121
- max_i = if m == n , do: n - 2 , else: n - 1
125
+ # this function originally supported generic QR, but
126
+ # it is now only used by eigh. Because of this,
127
+ # we simplified the function signature to only
128
+ # support square matrices.
122
129
123
130
{ q_matrix , r_matrix } =
124
- for i <- 0 .. max_i , reduce: { nil , matrix } do
131
+ for i <- 0 .. ( n - 2 ) // 1 , reduce: { nil , matrix } do
125
132
{ q , r } ->
126
133
h =
127
134
r
128
- |> slice_matrix ( [ i , i ] , [ m - i , 1 ] )
129
- |> householder_reflector ( m , eps )
135
+ |> slice_matrix ( [ i , i ] , [ n - i , 1 ] )
136
+ |> householder_reflector ( n , eps )
130
137
131
138
# If we haven't allocated Q yet, let Q = H1
132
139
# TODO: Resolve inconsistent with the Householder reflector.
@@ -145,10 +152,6 @@ defmodule Nx.BinaryBackend.Matrix do
145
152
{ approximate_zeros ( q_matrix , eps ) , approximate_zeros ( r_matrix , eps ) }
146
153
end
147
154
148
- defp qr_decomposition ( _ , _ , _ , _ ) do
149
- raise ArgumentError , "tensor must have at least as many rows as columns"
150
- end
151
-
152
155
defp raise_not_hermitian do
153
156
raise ArgumentError ,
154
157
"matrix must be hermitian, a matrix is hermitian iff X = adjoint(X)"
@@ -178,7 +181,7 @@ defmodule Nx.BinaryBackend.Matrix do
178
181
{ eigenvals_diag , eigenvecs } =
179
182
Enum . reduce_while ( 1 .. max_iter // 1 , { h , q_h } , fn _ , { a_old , q_old } ->
180
183
# QR decomposition
181
- { q_now , r_now } = qr_decomposition ( a_old , n , n , eps )
184
+ { q_now , r_now } = qr_decomposition ( a_old , n , eps )
182
185
183
186
# Update matrix A, Q
184
187
a_new = dot_matrix_real ( r_now , q_now )
@@ -203,10 +206,14 @@ defmodule Nx.BinaryBackend.Matrix do
203
206
eigenvecs |> approximate_zeros ( eps ) |> matrix_to_binary ( output_type ) }
204
207
end
205
208
209
+ defp hessenberg_decomposition ( matrix , n , _eps ) when n in 0 .. 1 do
210
+ { matrix , [ [ 1.0 ] ] }
211
+ end
212
+
206
213
defp hessenberg_decomposition ( matrix , n , eps ) do
207
214
# Hessenberg decomposition is performed by using Householder transform
208
215
{ hess_matrix , q_matrix } =
209
- for i <- 0 .. ( n - 2 ) , reduce: { matrix , nil } do
216
+ for i <- 0 .. ( n - 2 ) // 1 , reduce: { matrix , nil } do
210
217
{ hess , q } ->
211
218
h =
212
219
hess
0 commit comments