Skip to content

Commit 751736e

Browse files
committed
mixed precision cholesky with copy overhead
1 parent 9fbb981 commit 751736e

File tree

3 files changed

+144
-1
lines changed

3 files changed

+144
-1
lines changed

src/Dagger.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ include("array/linalg.jl")
8181
include("array/mul.jl")
8282
include("array/cholesky.jl")
8383
include("array/adapt_precision.jl")
84+
include("array/mixchol.jl")
8485
# Visualization
8586
include("visualization.jl")
8687
include("ui/gantt-common.jl")

src/array/adapt_precision.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ function adapt_precision(A::DArray{T,2}, tolerance::T) where {T}
157157

158158
global_norm = LinearAlgebra.norm2(A)
159159

160-
MP = fill("Float64", mt, nt)
160+
MP = fill(T, mt, nt)
161161
DMP = view(MP, Blocks(1, 1))
162162
MPc = DMP.chunks
163163

src/array/mixchol.jl

Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
function mixedtrsm!(side, uplo, trans, diag, alpha, A, B, StoragePrecision)
2+
T = StoragePrecision
3+
if typeof(B) != Matrix{T}
4+
println("B is not of type $T but of type $(typeof(B))")
5+
if typeof(A) != Matrix{T}
6+
Acopy = convert(Matrix{T}, A)
7+
else
8+
Acopy = A
9+
end
10+
Bcopy = convert(Matrix{T}, B)
11+
BLAS.trsm!(side, uplo, trans, diag, T(alpha), Acopy, Bcopy)
12+
end
13+
BLAS.trsm!(side, uplo, trans, diag, alpha, A, B)
14+
end
15+
function mixedgemm!(transa, transb, alpha, A, B, beta, C, StoragePrecision)
16+
T = StoragePrecision
17+
if typeof(C) != Matrix{T}
18+
if typeof(A) != Matrix{T}
19+
Acopy = convert(Matrix{T}, A)
20+
else
21+
Acopy = A
22+
end
23+
if typeof(B) != Matrix{T}
24+
Bcopy = convert(Matrix{T}, B)
25+
else
26+
Bcopy = B
27+
end
28+
Ccopy = convert(Matrix{T}, C)
29+
BLAS.gemm!(transa, transb, T(alpha), Acopy, Bcopy, T(beta), Ccopy)
30+
end
31+
BLAS.gemm!(transa, transb, alpha, A, B, beta, C)
32+
end
33+
function mixedsyrk!(uplo, trans, alpha, A, beta, C, StoragePrecision)
34+
T = StoragePrecision
35+
if typeof(C) != Matrix{T}
36+
if typeof(A) != Matrix{T}
37+
Acopy = convert(Matrix{T}, A)
38+
else
39+
Acopy = A
40+
end
41+
Ccopy = convert(Matrix{T}, C)
42+
BLAS.syrk!(uplo, trans, T(alpha), Acopy, T(beta), Ccopy)
43+
end
44+
BLAS.syrk!(uplo, trans, alpha, A, beta, C)
45+
end
46+
function mixedherk!(uplo, trans, alpha, A, beta, C, StoragePrecision)
47+
T = StoragePrecision
48+
if typeof(C) != Matrix{T}
49+
if typeof(A) != Matrix{T}
50+
Acopy = convert(Matrix{T}, A)
51+
else
52+
Acopy = A
53+
end
54+
Ccopy = convert(Matrix{T}, C)
55+
BLAS.herk!(uplo, trans, T(alpha), Acopy, T(beta), Ccopy)
56+
end
57+
BLAS.herk!(uplo, trans, alpha, A, beta, C)
58+
end
59+
function MixedPrecisionChol!(A::DArray{T,2}, ::Type{LowerTriangular}, MP::Matrix{DataType}) where T
60+
LinearAlgebra.checksquare(A)
61+
62+
zone = one(T)
63+
mzone = -one(T)
64+
rzone = one(real(T))
65+
rmzone = -one(real(T))
66+
uplo = 'L'
67+
Ac = A.chunks
68+
mt, nt = size(Ac)
69+
iscomplex = T <: Complex
70+
trans = iscomplex ? 'C' : 'T'
71+
72+
73+
info = [convert(LinearAlgebra.BlasInt, 0)]
74+
try
75+
Dagger.spawn_datadeps() do
76+
for k in range(1, mt)
77+
Dagger.@spawn potrf_checked!(uplo, InOut(Ac[k, k]), Out(info))
78+
for m in range(k+1, mt)
79+
Dagger.@spawn mixedtrsm!('R', uplo, trans, 'N', zone, In(Ac[k, k]), InOut(Ac[m, k]), MP[m,k])
80+
end
81+
for n in range(k+1, nt)
82+
if iscomplex
83+
Dagger.@spawn mixedherk!(uplo, 'N', rmzone, In(Ac[n, k]), rzone, InOut(Ac[n, n]), MP[n,n])
84+
else
85+
Dagger.@spawn mixedsyrk!(uplo, 'N', rmzone, In(Ac[n, k]), rzone, InOut(Ac[n, n]), MP[n,n])
86+
end
87+
for m in range(n+1, mt)
88+
Dagger.@spawn mixedgemm!('N', trans, mzone, In(Ac[m, k]), In(Ac[n, k]), zone, InOut(Ac[m, n]), MP[m,n])
89+
end
90+
end
91+
end
92+
end
93+
catch err
94+
err isa ThunkFailedException || rethrow()
95+
err = Dagger.Sch.unwrap_nested_exception(err.ex)
96+
err isa PosDefException || rethrow()
97+
end
98+
99+
return LowerTriangular(A), info[1]
100+
end
101+
102+
function MixedPrecisionChol!(A::DArray{T,2}, ::Type{UpperTriangular}, MP::Matrix{DataType}) where T
103+
LinearAlgebra.checksquare(A)
104+
105+
zone = one(T)
106+
mzone = -one(T)
107+
rzone = one(real(T))
108+
rmzone = -one(real(T))
109+
uplo = 'U'
110+
Ac = A.chunks
111+
mt, nt = size(Ac)
112+
iscomplex = T <: Complex
113+
trans = iscomplex ? 'C' : 'T'
114+
115+
info = [convert(LinearAlgebra.BlasInt, 0)]
116+
try
117+
Dagger.spawn_datadeps() do
118+
for k in range(1, mt)
119+
Dagger.@spawn potrf_checked!(uplo, InOut(Ac[k, k]), Out(info))
120+
for n in range(k+1, nt)
121+
Dagger.@spawn mixedtrsm!('L', uplo, trans, 'N', zone, In(Ac[k, k]), InOut(Ac[k, n]), MP[k,n])
122+
end
123+
for m in range(k+1, mt)
124+
if iscomplex
125+
Dagger.@spawn mixedherk!(uplo, 'C', rmzone, In(Ac[k, m]), rzone, InOut(Ac[m, m]))
126+
else
127+
Dagger.@spawn mixedherk!(uplo, 'T', rmzone, In(Ac[k, m]), rzone, InOut(Ac[m, m]))
128+
end
129+
for n in range(m+1, nt)
130+
Dagger.@spawn mixedgemm!(trans, 'N', mzone, In(Ac[k, m]), In(Ac[k, n]), zone, InOut(Ac[m, n]))
131+
end
132+
end
133+
end
134+
end
135+
catch err
136+
err isa ThunkFailedException || rethrow()
137+
err = Dagger.Sch.unwrap_nested_exception(err.ex)
138+
err isa PosDefException || rethrow()
139+
end
140+
141+
return UpperTriangular(A), info[1]
142+
end

0 commit comments

Comments
 (0)