Skip to content

Commit 27d8bb3

Browse files
committed
Adding documentation for lapack.py
1 parent fdbd933 commit 27d8bb3

File tree

1 file changed

+263
-3
lines changed

1 file changed

+263
-3
lines changed

arrayfire/lapack.py

Lines changed: 263 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,66 +7,301 @@
77
# http://arrayfire.com/licenses/BSD-3-Clause
88
########################################################
99

10+
"""
11+
dense linear algebra functions for arrayfire.
12+
"""
13+
1014
from .library import *
1115
from .array import *
1216

1317
def lu(A):
18+
"""
19+
LU decomposition.
20+
21+
Parameters
22+
----------
23+
A: af.Array
24+
A 2 dimensional arrayfire array.
25+
26+
Returns
27+
-------
28+
(L,U,P): tuple of af.Arrays
29+
- L - Lower triangular matrix.
30+
- U - Upper triangular matrix.
31+
- P - Permutation array.
32+
33+
Note
34+
----
35+
36+
The original matrix `A` can be reconstructed using the outputs in the following manner.
37+
38+
>>> A[P, :] = af.matmul(L, U)
39+
40+
"""
1441
L = Array()
1542
U = Array()
1643
P = Array()
1744
safe_call(backend.get().af_lu(ct.pointer(L.arr), ct.pointer(U.arr), ct.pointer(P.arr), A.arr))
1845
return L,U,P
1946

2047
def lu_inplace(A, pivot="lapack"):
48+
"""
49+
In place LU decomposition.
50+
51+
Parameters
52+
----------
53+
A: af.Array
54+
- a 2 dimensional arrayfire array on entry.
55+
- Contains L in the lower triangle on exit.
56+
- Contains U in the upper triangle on exit.
57+
58+
Returns
59+
-------
60+
P: af.Array
61+
- Permutation array.
62+
63+
Note
64+
----
65+
66+
This function is primarily used with `af.solve_lu` to reduce computations.
67+
68+
"""
2169
P = Array()
2270
is_pivot_lapack = False if (pivot == "full") else True
2371
safe_call(backend.get().af_lu_inplace(ct.pointer(P.arr), A.arr, is_pivot_lapack))
2472
return P
2573

2674
def qr(A):
75+
"""
76+
QR decomposition.
77+
78+
Parameters
79+
----------
80+
A: af.Array
81+
A 2 dimensional arrayfire array.
82+
83+
Returns
84+
-------
85+
(Q,R,T): tuple of af.Arrays
86+
- Q - Orthogonal matrix.
87+
- R - Upper triangular matrix.
88+
- T - Vector containing additional information to solve a least squares problem.
89+
90+
Note
91+
----
92+
93+
The outputs of this funciton have the following properties.
94+
95+
>>> A = af.matmul(Q, R)
96+
>>> I = af.matmulNT(Q, Q) # Identity matrix
97+
"""
2798
Q = Array()
2899
R = Array()
29100
T = Array()
30101
safe_call(backend.get().af_lu(ct.pointer(Q.arr), ct.pointer(R.arr), ct.pointer(T.arr), A.arr))
31102
return Q,R,T
32103

33104
def qr_inplace(A):
105+
"""
106+
In place QR decomposition.
107+
108+
Parameters
109+
----------
110+
A: af.Array
111+
- a 2 dimensional arrayfire array on entry.
112+
- Packed Q and R matrices on exit.
113+
114+
Returns
115+
-------
116+
T: af.Array
117+
- Vector containing additional information to solve a least squares problem.
118+
119+
Note
120+
----
121+
122+
This function is used to save space only when `R` is required.
123+
"""
34124
T = Array()
35125
safe_call(backend.get().af_qr_inplace(ct.pointer(T.arr), A.arr))
36126
return T
37127

38128
def cholesky(A, is_upper=True):
129+
"""
130+
Cholesky decomposition
131+
132+
Parameters
133+
----------
134+
A: af.Array
135+
A 2 dimensional, symmetric, positive definite matrix.
136+
137+
is_upper: optional: bool. default: True
138+
Specifies if output `R` is upper triangular (if True) or lower triangular (if False).
139+
140+
Returns
141+
-------
142+
(R,info): tuple of af.Array, int.
143+
- R - triangular matrix.
144+
- info - 0 if decomposition sucessful.
145+
Note
146+
----
147+
148+
The original matrix `A` can be reconstructed using the outputs in the following manner.
149+
150+
>>> A = af.matmulNT(R, R) #if R is upper triangular
151+
152+
"""
39153
R = Array()
40154
info = ct.c_int(0)
41155
safe_call(backend.get().af_cholesky(ct.pointer(R.arr), ct.pointer(info), A.arr, is_upper))
42156
return R, info.value
43157

44158
def cholesky_inplace(A, is_upper=True):
159+
"""
160+
In place Cholesky decomposition.
161+
162+
Parameters
163+
----------
164+
A: af.Array
165+
- a 2 dimensional, symmetric, positive definite matrix.
166+
- Trinangular matrix on exit.
167+
168+
is_upper: optional: bool. default: True.
169+
Specifies if output `R` is upper triangular (if True) or lower triangular (if False).
170+
171+
Returns
172+
-------
173+
info : int.
174+
0 if decomposition sucessful.
175+
176+
"""
45177
info = ct.c_int(0)
46178
safe_call(backend.get().af_cholesky_inplace(ct.pointer(info), A.arr, is_upper))
47179
return info.value
48180

49181
def solve(A, B, options=MATPROP.NONE):
182+
"""
183+
Solve a system of linear equations.
184+
185+
Parameters
186+
----------
187+
188+
A: af.Array
189+
A 2 dimensional arrayfire array representing the coefficients of the system.
190+
191+
B: af.Array
192+
A 1 or 2 dimensional arrayfire array representing the constants of the system.
193+
194+
options: optional: af.MATPROP. default: af.MATPROP.NONE.
195+
- Additional options to speed up computations.
196+
- Currently needs to be one of `af.MATPROP.NONE`, `af.MATPROP.LOWER`, `af.MATPROP.UPPER`.
197+
198+
Returns
199+
-------
200+
X: af.Array
201+
A 1 or 2 dimensional arrayfire array representing the unknowns in the system.
202+
203+
"""
50204
X = Array()
51205
safe_call(backend.get().af_solve(ct.pointer(X.arr), A.arr, B.arr, options.value))
52206
return X
53207

54208
def solve_lu(A, P, B, options=MATPROP.NONE):
209+
"""
210+
Solve a system of linear equations, using LU decomposition.
211+
212+
Parameters
213+
----------
214+
215+
A: af.Array
216+
- A 2 dimensional arrayfire array representing the coefficients of the system.
217+
- This matrix should be decomposed previously using `lu_inplace(A)`.
218+
219+
P: af.Array
220+
- Permutation array.
221+
- This array is the output of an earlier call to `lu_inplace(A)`
222+
223+
B: af.Array
224+
A 1 or 2 dimensional arrayfire array representing the constants of the system.
225+
226+
Returns
227+
-------
228+
X: af.Array
229+
A 1 or 2 dimensional arrayfire array representing the unknowns in the system.
230+
231+
"""
55232
X = Array()
56233
safe_call(backend.get().af_solve_lu(ct.pointer(X.arr), A.arr, P.arr, B.arr, options.value))
57234
return X
58235

59236
def inverse(A, options=MATPROP.NONE):
60-
I = Array()
61-
safe_call(backend.get().af_inverse(ct.pointer(I.arr), A.arr, options.value))
62-
return I
237+
"""
238+
Invert a matrix.
239+
240+
Parameters
241+
----------
242+
243+
A: af.Array
244+
- A 2 dimensional arrayfire array
245+
246+
options: optional: af.MATPROP. default: af.MATPROP.NONE.
247+
- Additional options to speed up computations.
248+
- Currently needs to be one of `af.MATPROP.NONE`.
249+
250+
Returns
251+
-------
252+
253+
AI: af.Array
254+
- A 2 dimensional array that is the inverse of `A`
255+
256+
Note
257+
----
258+
259+
`A` needs to be a square matrix.
260+
261+
"""
262+
AI = Array()
263+
safe_call(backend.get().af_inverse(ct.pointer(AI.arr), A.arr, options.value))
264+
return AI
63265

64266
def rank(A, tol=1E-5):
267+
"""
268+
Rank of a matrix.
269+
270+
Parameters
271+
----------
272+
273+
A: af.Array
274+
- A 2 dimensional arrayfire array
275+
276+
tol: optional: scalar. default: 1E-5.
277+
- Tolerance for calculating rank
278+
279+
Returns
280+
-------
281+
282+
r: int
283+
- Rank of `A` within the given tolerance
284+
"""
65285
r = ct.c_uint(0)
66286
safe_call(backend.get().af_rank(ct.pointer(r), A.arr, ct.c_double(tol)))
67287
return r.value
68288

69289
def det(A):
290+
"""
291+
Determinant of a matrix.
292+
293+
Parameters
294+
----------
295+
296+
A: af.Array
297+
- A 2 dimensional arrayfire array
298+
299+
Returns
300+
-------
301+
302+
res: scalar
303+
- Determinant of the matrix.
304+
"""
70305
re = ct.c_double(0)
71306
im = ct.c_double(0)
72307
safe_call(backend.get().af_det(ct.pointer(re), ct.pointer(im), A.arr))
@@ -75,6 +310,31 @@ def det(A):
75310
return re if (im == 0) else re + im * 1j
76311

77312
def norm(A, norm_type=NORM.EUCLID, p=1.0, q=1.0):
313+
"""
314+
Norm of an array or a matrix.
315+
316+
Parameters
317+
----------
318+
319+
A: af.Array
320+
- A 1 or 2 dimensional arrayfire array
321+
322+
norm_type: optional: af.NORM. default: af.NORM.EUCLID.
323+
- Type of norm to be calculated.
324+
325+
p: scalar. default 1.0.
326+
- Used only if `norm_type` is one of `af.NORM.VECTOR_P`, `af.NORM_MATRIX_L_PQ`
327+
328+
q: scalar. default 1.0.
329+
- Used only if `norm_type` is `af.NORM_MATRIX_L_PQ`
330+
331+
Returns
332+
-------
333+
334+
res: scalar
335+
- norm of the input
336+
337+
"""
78338
res = ct.c_double(0)
79339
safe_call(backend.get().af_norm(ct.pointer(res), A.arr, norm_type.value,
80340
ct.c_double(p), ct.c_double(q)))

0 commit comments

Comments
 (0)