Skip to content

Commit 2c79746

Browse files
committed
FEAT: Adding SVD
1 parent 6234d53 commit 2c79746

File tree

2 files changed

+77
-1
lines changed

2 files changed

+77
-1
lines changed

arrayfire/lapack.py

Lines changed: 70 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
########################################################
99

1010
"""
11-
dense linear algebra functions for arrayfire.
11+
Dense Linear Algebra functions for arrayfire.
1212
"""
1313

1414
from .library import *
@@ -339,3 +339,72 @@ def norm(A, norm_type=NORM.EUCLID, p=1.0, q=1.0):
339339
safe_call(backend.get().af_norm(ct.pointer(res), A.arr, norm_type.value,
340340
ct.c_double(p), ct.c_double(q)))
341341
return res.value
342+
343+
def svd(A):
344+
"""
345+
Singular Value Decomposition
346+
347+
Parameters
348+
----------
349+
A: af.Array
350+
A 2 dimensional arrayfire array.
351+
352+
Returns
353+
-------
354+
(U,S,Vt): tuple of af.Arrays
355+
- U - A unitary matrix
356+
- S - An array containing the elements of diagonal matrix
357+
- Vt - A unitary matrix
358+
359+
Note
360+
----
361+
362+
- The original matrix `A` is preserved and additional storage space is required for decomposition.
363+
364+
- If the original matrix `A` need not be preserved, use `svd_inplace` instead.
365+
366+
- The original matrix `A` can be reconstructed using the outputs in the following manner.
367+
>>> Smat = af.diag(S, 0, False)
368+
>>> A_recon = af.matmul(af.matmul(U, Smat), Vt)
369+
370+
"""
371+
U = Array()
372+
S = Array()
373+
Vt = Array()
374+
safe_call(backend.get().af_svd(ct.pointer(U.arr), ct.pointer(S.arr), ct.pointer(Vt.arr), A.arr))
375+
return U, S, Vt
376+
377+
def svd_inplace(A):
378+
"""
379+
Singular Value Decomposition
380+
381+
Parameters
382+
----------
383+
A: af.Array
384+
A 2 dimensional arrayfire array.
385+
386+
Returns
387+
-------
388+
(U,S,Vt): tuple of af.Arrays
389+
- U - A unitary matrix
390+
- S - An array containing the elements of diagonal matrix
391+
- Vt - A unitary matrix
392+
393+
Note
394+
----
395+
396+
- The original matrix `A` is not preserved.
397+
398+
- If the original matrix `A` needs to be preserved, use `svd` instead.
399+
400+
- The original matrix `A` can be reconstructed using the outputs in the following manner.
401+
>>> Smat = af.diag(S, 0, False)
402+
>>> A_recon = af.matmul(af.matmul(U, Smat), Vt)
403+
404+
"""
405+
U = Array()
406+
S = Array()
407+
Vt = Array()
408+
safe_call(backend.get().af_svd_inplace(ct.pointer(U.arr), ct.pointer(S.arr), ct.pointer(Vt.arr),
409+
A.arr))
410+
return U, S, Vt

tests/simple/lapack.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,4 +74,11 @@ def simple_lapack(verbose=False):
7474
print_func(af.norm(a, af.NORM.MATRIX_INF))
7575
print_func(af.norm(a, af.NORM.MATRIX_L_PQ, 1, 1))
7676

77+
a = af.randu(10,10)
78+
display_func(a)
79+
u,s,vt = af.svd(a)
80+
display_func(af.matmul(af.matmul(u, af.diag(s, 0, False)), vt))
81+
u,s,vt = af.svd_inplace(a)
82+
display_func(af.matmul(af.matmul(u, af.diag(s, 0, False)), vt))
83+
7784
_util.tests['lapack'] = simple_lapack

0 commit comments

Comments
 (0)