1
+ function mixedtrsm! (side, uplo, trans, diag, alpha, A, B, StoragePrecision)
2
+ T = StoragePrecision
3
+ if typeof (B) != Matrix{T}
4
+ println (" B is not of type $T but of type $(typeof (B)) " )
5
+ if typeof (A) != Matrix{T}
6
+ Acopy = convert (Matrix{T}, A)
7
+ else
8
+ Acopy = A
9
+ end
10
+ Bcopy = convert (Matrix{T}, B)
11
+ BLAS. trsm! (side, uplo, trans, diag, T (alpha), Acopy, Bcopy)
12
+ end
13
+ BLAS. trsm! (side, uplo, trans, diag, alpha, A, B)
14
+ end
15
+ function mixedgemm! (transa, transb, alpha, A, B, beta, C, StoragePrecision)
16
+ T = StoragePrecision
17
+ if typeof (C) != Matrix{T}
18
+ if typeof (A) != Matrix{T}
19
+ Acopy = convert (Matrix{T}, A)
20
+ else
21
+ Acopy = A
22
+ end
23
+ if typeof (B) != Matrix{T}
24
+ Bcopy = convert (Matrix{T}, B)
25
+ else
26
+ Bcopy = B
27
+ end
28
+ Ccopy = convert (Matrix{T}, C)
29
+ BLAS. gemm! (transa, transb, T (alpha), Acopy, Bcopy, T (beta), Ccopy)
30
+ end
31
+ BLAS. gemm! (transa, transb, alpha, A, B, beta, C)
32
+ end
33
+ function mixedsyrk! (uplo, trans, alpha, A, beta, C, StoragePrecision)
34
+ T = StoragePrecision
35
+ if typeof (C) != Matrix{T}
36
+ if typeof (A) != Matrix{T}
37
+ Acopy = convert (Matrix{T}, A)
38
+ else
39
+ Acopy = A
40
+ end
41
+ Ccopy = convert (Matrix{T}, C)
42
+ BLAS. syrk! (uplo, trans, T (alpha), Acopy, T (beta), Ccopy)
43
+ end
44
+ BLAS. syrk! (uplo, trans, alpha, A, beta, C)
45
+ end
46
+ function mixedherk! (uplo, trans, alpha, A, beta, C, StoragePrecision)
47
+ T = StoragePrecision
48
+ if typeof (C) != Matrix{T}
49
+ if typeof (A) != Matrix{T}
50
+ Acopy = convert (Matrix{T}, A)
51
+ else
52
+ Acopy = A
53
+ end
54
+ Ccopy = convert (Matrix{T}, C)
55
+ BLAS. herk! (uplo, trans, T (alpha), Acopy, T (beta), Ccopy)
56
+ end
57
+ BLAS. herk! (uplo, trans, alpha, A, beta, C)
58
+ end
59
+ function MixedPrecisionChol! (A:: DArray{T,2} , :: Type{LowerTriangular} , MP:: Matrix{DataType} ) where T
60
+ LinearAlgebra. checksquare (A)
61
+
62
+ zone = one (T)
63
+ mzone = - one (T)
64
+ rzone = one (real (T))
65
+ rmzone = - one (real (T))
66
+ uplo = ' L'
67
+ Ac = A. chunks
68
+ mt, nt = size (Ac)
69
+ iscomplex = T <: Complex
70
+ trans = iscomplex ? ' C' : ' T'
71
+
72
+
73
+ info = [convert (LinearAlgebra. BlasInt, 0 )]
74
+ try
75
+ Dagger. spawn_datadeps () do
76
+ for k in range (1 , mt)
77
+ Dagger. @spawn potrf_checked! (uplo, InOut (Ac[k, k]), Out (info))
78
+ for m in range (k+ 1 , mt)
79
+ Dagger. @spawn mixedtrsm! (' R' , uplo, trans, ' N' , zone, In (Ac[k, k]), InOut (Ac[m, k]), MP[m,k])
80
+ end
81
+ for n in range (k+ 1 , nt)
82
+ if iscomplex
83
+ Dagger. @spawn mixedherk! (uplo, ' N' , rmzone, In (Ac[n, k]), rzone, InOut (Ac[n, n]), MP[n,n])
84
+ else
85
+ Dagger. @spawn mixedsyrk! (uplo, ' N' , rmzone, In (Ac[n, k]), rzone, InOut (Ac[n, n]), MP[n,n])
86
+ end
87
+ for m in range (n+ 1 , mt)
88
+ Dagger. @spawn mixedgemm! (' N' , trans, mzone, In (Ac[m, k]), In (Ac[n, k]), zone, InOut (Ac[m, n]), MP[m,n])
89
+ end
90
+ end
91
+ end
92
+ end
93
+ catch err
94
+ err isa ThunkFailedException || rethrow ()
95
+ err = Dagger. Sch. unwrap_nested_exception (err. ex)
96
+ err isa PosDefException || rethrow ()
97
+ end
98
+
99
+ return LowerTriangular (A), info[1 ]
100
+ end
101
+
102
+ function MixedPrecisionChol! (A:: DArray{T,2} , :: Type{UpperTriangular} , MP:: Matrix{DataType} ) where T
103
+ LinearAlgebra. checksquare (A)
104
+
105
+ zone = one (T)
106
+ mzone = - one (T)
107
+ rzone = one (real (T))
108
+ rmzone = - one (real (T))
109
+ uplo = ' U'
110
+ Ac = A. chunks
111
+ mt, nt = size (Ac)
112
+ iscomplex = T <: Complex
113
+ trans = iscomplex ? ' C' : ' T'
114
+
115
+ info = [convert (LinearAlgebra. BlasInt, 0 )]
116
+ try
117
+ Dagger. spawn_datadeps () do
118
+ for k in range (1 , mt)
119
+ Dagger. @spawn potrf_checked! (uplo, InOut (Ac[k, k]), Out (info))
120
+ for n in range (k+ 1 , nt)
121
+ Dagger. @spawn mixedtrsm! (' L' , uplo, trans, ' N' , zone, In (Ac[k, k]), InOut (Ac[k, n]), MP[k,n])
122
+ end
123
+ for m in range (k+ 1 , mt)
124
+ if iscomplex
125
+ Dagger. @spawn mixedherk! (uplo, ' C' , rmzone, In (Ac[k, m]), rzone, InOut (Ac[m, m]))
126
+ else
127
+ Dagger. @spawn mixedherk! (uplo, ' T' , rmzone, In (Ac[k, m]), rzone, InOut (Ac[m, m]))
128
+ end
129
+ for n in range (m+ 1 , nt)
130
+ Dagger. @spawn mixedgemm! (trans, ' N' , mzone, In (Ac[k, m]), In (Ac[k, n]), zone, InOut (Ac[m, n]))
131
+ end
132
+ end
133
+ end
134
+ end
135
+ catch err
136
+ err isa ThunkFailedException || rethrow ()
137
+ err = Dagger. Sch. unwrap_nested_exception (err. ex)
138
+ err isa PosDefException || rethrow ()
139
+ end
140
+
141
+ return UpperTriangular (A), info[1 ]
142
+ end
0 commit comments