7
7
# http://arrayfire.com/licenses/BSD-3-Clause
8
8
########################################################
9
9
10
+ """
11
+ dense linear algebra functions for arrayfire.
12
+ """
13
+
10
14
from .library import *
11
15
from .array import *
12
16
13
17
def lu (A ):
18
+ """
19
+ LU decomposition.
20
+
21
+ Parameters
22
+ ----------
23
+ A: af.Array
24
+ A 2 dimensional arrayfire array.
25
+
26
+ Returns
27
+ -------
28
+ (L,U,P): tuple of af.Arrays
29
+ - L - Lower triangular matrix.
30
+ - U - Upper triangular matrix.
31
+ - P - Permutation array.
32
+
33
+ Note
34
+ ----
35
+
36
+ The original matrix `A` can be reconstructed using the outputs in the following manner.
37
+
38
+ >>> A[P, :] = af.matmul(L, U)
39
+
40
+ """
14
41
L = Array ()
15
42
U = Array ()
16
43
P = Array ()
17
44
safe_call (backend .get ().af_lu (ct .pointer (L .arr ), ct .pointer (U .arr ), ct .pointer (P .arr ), A .arr ))
18
45
return L ,U ,P
19
46
20
47
def lu_inplace (A , pivot = "lapack" ):
48
+ """
49
+ In place LU decomposition.
50
+
51
+ Parameters
52
+ ----------
53
+ A: af.Array
54
+ - a 2 dimensional arrayfire array on entry.
55
+ - Contains L in the lower triangle on exit.
56
+ - Contains U in the upper triangle on exit.
57
+
58
+ Returns
59
+ -------
60
+ P: af.Array
61
+ - Permutation array.
62
+
63
+ Note
64
+ ----
65
+
66
+ This function is primarily used with `af.solve_lu` to reduce computations.
67
+
68
+ """
21
69
P = Array ()
22
70
is_pivot_lapack = False if (pivot == "full" ) else True
23
71
safe_call (backend .get ().af_lu_inplace (ct .pointer (P .arr ), A .arr , is_pivot_lapack ))
24
72
return P
25
73
26
74
def qr (A ):
75
+ """
76
+ QR decomposition.
77
+
78
+ Parameters
79
+ ----------
80
+ A: af.Array
81
+ A 2 dimensional arrayfire array.
82
+
83
+ Returns
84
+ -------
85
+ (Q,R,T): tuple of af.Arrays
86
+ - Q - Orthogonal matrix.
87
+ - R - Upper triangular matrix.
88
+ - T - Vector containing additional information to solve a least squares problem.
89
+
90
+ Note
91
+ ----
92
+
93
+ The outputs of this funciton have the following properties.
94
+
95
+ >>> A = af.matmul(Q, R)
96
+ >>> I = af.matmulNT(Q, Q) # Identity matrix
97
+ """
27
98
Q = Array ()
28
99
R = Array ()
29
100
T = Array ()
30
101
safe_call (backend .get ().af_lu (ct .pointer (Q .arr ), ct .pointer (R .arr ), ct .pointer (T .arr ), A .arr ))
31
102
return Q ,R ,T
32
103
33
104
def qr_inplace (A ):
105
+ """
106
+ In place QR decomposition.
107
+
108
+ Parameters
109
+ ----------
110
+ A: af.Array
111
+ - a 2 dimensional arrayfire array on entry.
112
+ - Packed Q and R matrices on exit.
113
+
114
+ Returns
115
+ -------
116
+ T: af.Array
117
+ - Vector containing additional information to solve a least squares problem.
118
+
119
+ Note
120
+ ----
121
+
122
+ This function is used to save space only when `R` is required.
123
+ """
34
124
T = Array ()
35
125
safe_call (backend .get ().af_qr_inplace (ct .pointer (T .arr ), A .arr ))
36
126
return T
37
127
38
128
def cholesky (A , is_upper = True ):
129
+ """
130
+ Cholesky decomposition
131
+
132
+ Parameters
133
+ ----------
134
+ A: af.Array
135
+ A 2 dimensional, symmetric, positive definite matrix.
136
+
137
+ is_upper: optional: bool. default: True
138
+ Specifies if output `R` is upper triangular (if True) or lower triangular (if False).
139
+
140
+ Returns
141
+ -------
142
+ (R,info): tuple of af.Array, int.
143
+ - R - triangular matrix.
144
+ - info - 0 if decomposition sucessful.
145
+ Note
146
+ ----
147
+
148
+ The original matrix `A` can be reconstructed using the outputs in the following manner.
149
+
150
+ >>> A = af.matmulNT(R, R) #if R is upper triangular
151
+
152
+ """
39
153
R = Array ()
40
154
info = ct .c_int (0 )
41
155
safe_call (backend .get ().af_cholesky (ct .pointer (R .arr ), ct .pointer (info ), A .arr , is_upper ))
42
156
return R , info .value
43
157
44
158
def cholesky_inplace (A , is_upper = True ):
159
+ """
160
+ In place Cholesky decomposition.
161
+
162
+ Parameters
163
+ ----------
164
+ A: af.Array
165
+ - a 2 dimensional, symmetric, positive definite matrix.
166
+ - Trinangular matrix on exit.
167
+
168
+ is_upper: optional: bool. default: True.
169
+ Specifies if output `R` is upper triangular (if True) or lower triangular (if False).
170
+
171
+ Returns
172
+ -------
173
+ info : int.
174
+ 0 if decomposition sucessful.
175
+
176
+ """
45
177
info = ct .c_int (0 )
46
178
safe_call (backend .get ().af_cholesky_inplace (ct .pointer (info ), A .arr , is_upper ))
47
179
return info .value
48
180
49
181
def solve (A , B , options = MATPROP .NONE ):
182
+ """
183
+ Solve a system of linear equations.
184
+
185
+ Parameters
186
+ ----------
187
+
188
+ A: af.Array
189
+ A 2 dimensional arrayfire array representing the coefficients of the system.
190
+
191
+ B: af.Array
192
+ A 1 or 2 dimensional arrayfire array representing the constants of the system.
193
+
194
+ options: optional: af.MATPROP. default: af.MATPROP.NONE.
195
+ - Additional options to speed up computations.
196
+ - Currently needs to be one of `af.MATPROP.NONE`, `af.MATPROP.LOWER`, `af.MATPROP.UPPER`.
197
+
198
+ Returns
199
+ -------
200
+ X: af.Array
201
+ A 1 or 2 dimensional arrayfire array representing the unknowns in the system.
202
+
203
+ """
50
204
X = Array ()
51
205
safe_call (backend .get ().af_solve (ct .pointer (X .arr ), A .arr , B .arr , options .value ))
52
206
return X
53
207
54
208
def solve_lu (A , P , B , options = MATPROP .NONE ):
209
+ """
210
+ Solve a system of linear equations, using LU decomposition.
211
+
212
+ Parameters
213
+ ----------
214
+
215
+ A: af.Array
216
+ - A 2 dimensional arrayfire array representing the coefficients of the system.
217
+ - This matrix should be decomposed previously using `lu_inplace(A)`.
218
+
219
+ P: af.Array
220
+ - Permutation array.
221
+ - This array is the output of an earlier call to `lu_inplace(A)`
222
+
223
+ B: af.Array
224
+ A 1 or 2 dimensional arrayfire array representing the constants of the system.
225
+
226
+ Returns
227
+ -------
228
+ X: af.Array
229
+ A 1 or 2 dimensional arrayfire array representing the unknowns in the system.
230
+
231
+ """
55
232
X = Array ()
56
233
safe_call (backend .get ().af_solve_lu (ct .pointer (X .arr ), A .arr , P .arr , B .arr , options .value ))
57
234
return X
58
235
59
236
def inverse (A , options = MATPROP .NONE ):
60
- I = Array ()
61
- safe_call (backend .get ().af_inverse (ct .pointer (I .arr ), A .arr , options .value ))
62
- return I
237
+ """
238
+ Invert a matrix.
239
+
240
+ Parameters
241
+ ----------
242
+
243
+ A: af.Array
244
+ - A 2 dimensional arrayfire array
245
+
246
+ options: optional: af.MATPROP. default: af.MATPROP.NONE.
247
+ - Additional options to speed up computations.
248
+ - Currently needs to be one of `af.MATPROP.NONE`.
249
+
250
+ Returns
251
+ -------
252
+
253
+ AI: af.Array
254
+ - A 2 dimensional array that is the inverse of `A`
255
+
256
+ Note
257
+ ----
258
+
259
+ `A` needs to be a square matrix.
260
+
261
+ """
262
+ AI = Array ()
263
+ safe_call (backend .get ().af_inverse (ct .pointer (AI .arr ), A .arr , options .value ))
264
+ return AI
63
265
64
266
def rank (A , tol = 1E-5 ):
267
+ """
268
+ Rank of a matrix.
269
+
270
+ Parameters
271
+ ----------
272
+
273
+ A: af.Array
274
+ - A 2 dimensional arrayfire array
275
+
276
+ tol: optional: scalar. default: 1E-5.
277
+ - Tolerance for calculating rank
278
+
279
+ Returns
280
+ -------
281
+
282
+ r: int
283
+ - Rank of `A` within the given tolerance
284
+ """
65
285
r = ct .c_uint (0 )
66
286
safe_call (backend .get ().af_rank (ct .pointer (r ), A .arr , ct .c_double (tol )))
67
287
return r .value
68
288
69
289
def det (A ):
290
+ """
291
+ Determinant of a matrix.
292
+
293
+ Parameters
294
+ ----------
295
+
296
+ A: af.Array
297
+ - A 2 dimensional arrayfire array
298
+
299
+ Returns
300
+ -------
301
+
302
+ res: scalar
303
+ - Determinant of the matrix.
304
+ """
70
305
re = ct .c_double (0 )
71
306
im = ct .c_double (0 )
72
307
safe_call (backend .get ().af_det (ct .pointer (re ), ct .pointer (im ), A .arr ))
@@ -75,6 +310,31 @@ def det(A):
75
310
return re if (im == 0 ) else re + im * 1j
76
311
77
312
def norm (A , norm_type = NORM .EUCLID , p = 1.0 , q = 1.0 ):
313
+ """
314
+ Norm of an array or a matrix.
315
+
316
+ Parameters
317
+ ----------
318
+
319
+ A: af.Array
320
+ - A 1 or 2 dimensional arrayfire array
321
+
322
+ norm_type: optional: af.NORM. default: af.NORM.EUCLID.
323
+ - Type of norm to be calculated.
324
+
325
+ p: scalar. default 1.0.
326
+ - Used only if `norm_type` is one of `af.NORM.VECTOR_P`, `af.NORM_MATRIX_L_PQ`
327
+
328
+ q: scalar. default 1.0.
329
+ - Used only if `norm_type` is `af.NORM_MATRIX_L_PQ`
330
+
331
+ Returns
332
+ -------
333
+
334
+ res: scalar
335
+ - norm of the input
336
+
337
+ """
78
338
res = ct .c_double (0 )
79
339
safe_call (backend .get ().af_norm (ct .pointer (res ), A .arr , norm_type .value ,
80
340
ct .c_double (p ), ct .c_double (q )))
0 commit comments