1+ function mixedtrsm! (side, uplo, trans, diag, alpha, A, B, StoragePrecision)
2+ T = StoragePrecision
3+ if typeof (B) != Matrix{T}
4+ println (" B is not of type $T but of type $(typeof (B)) " )
5+ if typeof (A) != Matrix{T}
6+ Acopy = convert (Matrix{T}, A)
7+ else
8+ Acopy = A
9+ end
10+ Bcopy = convert (Matrix{T}, B)
11+ BLAS. trsm! (side, uplo, trans, diag, T (alpha), Acopy, Bcopy)
12+ end
13+ BLAS. trsm! (side, uplo, trans, diag, alpha, A, B)
14+ end
15+ function mixedgemm! (transa, transb, alpha, A, B, beta, C, StoragePrecision)
16+ T = StoragePrecision
17+ if typeof (C) != Matrix{T}
18+ if typeof (A) != Matrix{T}
19+ Acopy = convert (Matrix{T}, A)
20+ else
21+ Acopy = A
22+ end
23+ if typeof (B) != Matrix{T}
24+ Bcopy = convert (Matrix{T}, B)
25+ else
26+ Bcopy = B
27+ end
28+ Ccopy = convert (Matrix{T}, C)
29+ BLAS. gemm! (transa, transb, T (alpha), Acopy, Bcopy, T (beta), Ccopy)
30+ end
31+ BLAS. gemm! (transa, transb, alpha, A, B, beta, C)
32+ end
33+ function mixedsyrk! (uplo, trans, alpha, A, beta, C, StoragePrecision)
34+ T = StoragePrecision
35+ if typeof (C) != Matrix{T}
36+ if typeof (A) != Matrix{T}
37+ Acopy = convert (Matrix{T}, A)
38+ else
39+ Acopy = A
40+ end
41+ Ccopy = convert (Matrix{T}, C)
42+ BLAS. syrk! (uplo, trans, T (alpha), Acopy, T (beta), Ccopy)
43+ end
44+ BLAS. syrk! (uplo, trans, alpha, A, beta, C)
45+ end
46+ function mixedherk! (uplo, trans, alpha, A, beta, C, StoragePrecision)
47+ T = StoragePrecision
48+ if typeof (C) != Matrix{T}
49+ if typeof (A) != Matrix{T}
50+ Acopy = convert (Matrix{T}, A)
51+ else
52+ Acopy = A
53+ end
54+ Ccopy = convert (Matrix{T}, C)
55+ BLAS. herk! (uplo, trans, T (alpha), Acopy, T (beta), Ccopy)
56+ end
57+ BLAS. herk! (uplo, trans, alpha, A, beta, C)
58+ end
59+ function MixedPrecisionChol! (A:: DArray{T,2} , :: Type{LowerTriangular} , MP:: Matrix{DataType} ) where T
60+ LinearAlgebra. checksquare (A)
61+
62+ zone = one (T)
63+ mzone = - one (T)
64+ rzone = one (real (T))
65+ rmzone = - one (real (T))
66+ uplo = ' L'
67+ Ac = A. chunks
68+ mt, nt = size (Ac)
69+ iscomplex = T <: Complex
70+ trans = iscomplex ? ' C' : ' T'
71+
72+
73+ info = [convert (LinearAlgebra. BlasInt, 0 )]
74+ try
75+ Dagger. spawn_datadeps () do
76+ for k in range (1 , mt)
77+ Dagger. @spawn potrf_checked! (uplo, InOut (Ac[k, k]), Out (info))
78+ for m in range (k+ 1 , mt)
79+ Dagger. @spawn mixedtrsm! (' R' , uplo, trans, ' N' , zone, In (Ac[k, k]), InOut (Ac[m, k]), MP[m,k])
80+ end
81+ for n in range (k+ 1 , nt)
82+ if iscomplex
83+ Dagger. @spawn mixedherk! (uplo, ' N' , rmzone, In (Ac[n, k]), rzone, InOut (Ac[n, n]), MP[n,n])
84+ else
85+ Dagger. @spawn mixedsyrk! (uplo, ' N' , rmzone, In (Ac[n, k]), rzone, InOut (Ac[n, n]), MP[n,n])
86+ end
87+ for m in range (n+ 1 , mt)
88+ Dagger. @spawn mixedgemm! (' N' , trans, mzone, In (Ac[m, k]), In (Ac[n, k]), zone, InOut (Ac[m, n]), MP[m,n])
89+ end
90+ end
91+ end
92+ end
93+ catch err
94+ err isa ThunkFailedException || rethrow ()
95+ err = Dagger. Sch. unwrap_nested_exception (err. ex)
96+ err isa PosDefException || rethrow ()
97+ end
98+
99+ return LowerTriangular (A), info[1 ]
100+ end
101+
102+ function MixedPrecisionChol! (A:: DArray{T,2} , :: Type{UpperTriangular} , MP:: Matrix{DataType} ) where T
103+ LinearAlgebra. checksquare (A)
104+
105+ zone = one (T)
106+ mzone = - one (T)
107+ rzone = one (real (T))
108+ rmzone = - one (real (T))
109+ uplo = ' U'
110+ Ac = A. chunks
111+ mt, nt = size (Ac)
112+ iscomplex = T <: Complex
113+ trans = iscomplex ? ' C' : ' T'
114+
115+ info = [convert (LinearAlgebra. BlasInt, 0 )]
116+ try
117+ Dagger. spawn_datadeps () do
118+ for k in range (1 , mt)
119+ Dagger. @spawn potrf_checked! (uplo, InOut (Ac[k, k]), Out (info))
120+ for n in range (k+ 1 , nt)
121+ Dagger. @spawn mixedtrsm! (' L' , uplo, trans, ' N' , zone, In (Ac[k, k]), InOut (Ac[k, n]), MP[k,n])
122+ end
123+ for m in range (k+ 1 , mt)
124+ if iscomplex
125+ Dagger. @spawn mixedherk! (uplo, ' C' , rmzone, In (Ac[k, m]), rzone, InOut (Ac[m, m]))
126+ else
127+ Dagger. @spawn mixedherk! (uplo, ' T' , rmzone, In (Ac[k, m]), rzone, InOut (Ac[m, m]))
128+ end
129+ for n in range (m+ 1 , nt)
130+ Dagger. @spawn mixedgemm! (trans, ' N' , mzone, In (Ac[k, m]), In (Ac[k, n]), zone, InOut (Ac[m, n]))
131+ end
132+ end
133+ end
134+ end
135+ catch err
136+ err isa ThunkFailedException || rethrow ()
137+ err = Dagger. Sch. unwrap_nested_exception (err. ex)
138+ err isa PosDefException || rethrow ()
139+ end
140+
141+ return UpperTriangular (A), info[1 ]
142+ end
0 commit comments