|
# Copy hook for two-dimensional `DArray`s: returns `copy(A)`, leaving `A` itself untouched.
function LinearAlgebra.cholcopy(A::DArray{T,2}) where T
    return copy(A)
end
# Factorize `A` in place with `LAPACK.potrf!` and turn a positive status code
# into an exception.
#
# Arguments:
#   uplo     - triangle selector forwarded unchanged to `LAPACK.potrf!`
#   A        - matrix that is overwritten by the factorization
#   info_arr - indexable container; slot 1 receives the status code on failure
#
# Returns `(A, status)` as produced by `LAPACK.potrf!` when the status is not
# positive. Otherwise the status is written to `info_arr[1]` and a
# `PosDefException(status)` is thrown.
function potrf_checked!(uplo, A, info_arr)
    factored, status = LAPACK.potrf!(uplo, A)
    status > 0 || return factored, status

    # Record the code before throwing so the caller can still read it from
    # `info_arr` after the exception has propagated.
    info_arr[1] = status
    throw(PosDefException(status))
end
# Tiled, in-place Cholesky factorization of a two-dimensional `DArray`,
# working on the upper triangle.
#
# The chunk grid `A.chunks` is swept one diagonal tile at a time. For step `k`:
#   1. `potrf_checked!` factorizes the diagonal tile (k, k);
#   2. `BLAS.trsm!` solves against that tile for every tile (k, n), n > k;
#   3. `BLAS.herk!` / `BLAS.syrk!` update each later diagonal tile (m, m), and
#      `BLAS.gemm!` updates the tiles (m, n) to its right, m > k, n > m.
# All kernels are submitted inside `Dagger.spawn_datadeps`, with `In` / `InOut` /
# `Out` annotating how each task uses its arguments.
#
# Returns `(UpperTriangular(A), info[1])`. A `PosDefException` raised inside a
# task is swallowed here; any other failure is rethrown.
#
# NOTE(review): the code stored by `potrf_checked!` is whatever `LAPACK.potrf!`
# reports for a single tile, so it looks tile-relative rather than relative to
# the full matrix — confirm against callers.
function LinearAlgebra._chol!(A::DArray{T,2}, ::Type{UpperTriangular}) where T
    LinearAlgebra.checksquare(A)

    uplo = 'U'
    iscomplex = T <: Complex
    # Conjugate transpose for complex element types, plain transpose otherwise.
    trans = iscomplex ? 'C' : 'T'
    # Rank-k update kernel matching the element type.
    rankk! = iscomplex ? BLAS.herk! : BLAS.syrk!

    # trsm!/gemm! receive scalars of type T; the rank-k update receives
    # scalars of type real(T).
    pos_one = one(T)
    neg_one = -one(T)
    real_pos_one = one(real(T))
    real_neg_one = -one(real(T))

    tiles = A.chunks
    mt, nt = size(tiles)

    info = [convert(LinearAlgebra.BlasInt, 0)]
    try
        Dagger.spawn_datadeps() do
            for k in 1:mt
                diag = tiles[k, k]
                Dagger.@spawn potrf_checked!(uplo, InOut(diag), Out(info))
                for n in (k + 1):nt
                    Dagger.@spawn BLAS.trsm!(
                        'L', uplo, trans, 'N', pos_one, In(diag), InOut(tiles[k, n])
                    )
                end
                for m in (k + 1):mt
                    Dagger.@spawn rankk!(
                        uplo, trans, real_neg_one, In(tiles[k, m]),
                        real_pos_one, InOut(tiles[m, m]),
                    )
                    for n in (m + 1):nt
                        Dagger.@spawn BLAS.gemm!(
                            trans, 'N', neg_one, In(tiles[k, m]), In(tiles[k, n]),
                            pos_one, InOut(tiles[m, n]),
                        )
                    end
                end
            end
        end
    catch failure
        # Only a task failure whose unwrapped cause is a PosDefException is
        # tolerated; everything else propagates unchanged.
        failure isa ThunkFailedException || rethrow()
        cause = Dagger.Sch.unwrap_nested_exception(failure.ex)
        cause isa PosDefException || rethrow()
    end

    return UpperTriangular(A), info[1]
end
# Tiled, in-place Cholesky factorization of a two-dimensional `DArray`,
# working on the lower triangle.
#
# The chunk grid `A.chunks` is swept one diagonal tile at a time. For step `k`:
#   1. `potrf_checked!` factorizes the diagonal tile (k, k);
#   2. `BLAS.trsm!` solves against that tile for every tile (m, k), m > k;
#   3. `BLAS.herk!` / `BLAS.syrk!` update each later diagonal tile (n, n), and
#      `BLAS.gemm!` updates the tiles (m, n) below it, n > k, m > n.
# All kernels are submitted inside `Dagger.spawn_datadeps`, with `In` / `InOut` /
# `Out` annotating how each task uses its arguments.
#
# Returns `(LowerTriangular(A), info[1])`. A `PosDefException` raised inside a
# task is swallowed here; any other failure is rethrown.
#
# NOTE(review): the code stored by `potrf_checked!` is whatever `LAPACK.potrf!`
# reports for a single tile, so it looks tile-relative rather than relative to
# the full matrix — confirm against callers.
function LinearAlgebra._chol!(A::DArray{T,2}, ::Type{LowerTriangular}) where T
    LinearAlgebra.checksquare(A)

    uplo = 'L'
    iscomplex = T <: Complex
    # Conjugate transpose for complex element types, plain transpose otherwise.
    trans = iscomplex ? 'C' : 'T'
    # Rank-k update kernel matching the element type.
    rankk! = iscomplex ? BLAS.herk! : BLAS.syrk!

    # trsm!/gemm! receive scalars of type T; the rank-k update receives
    # scalars of type real(T).
    pos_one = one(T)
    neg_one = -one(T)
    real_pos_one = one(real(T))
    real_neg_one = -one(real(T))

    tiles = A.chunks
    mt, nt = size(tiles)

    info = [convert(LinearAlgebra.BlasInt, 0)]
    try
        Dagger.spawn_datadeps() do
            for k in 1:mt
                diag = tiles[k, k]
                Dagger.@spawn potrf_checked!(uplo, InOut(diag), Out(info))
                for m in (k + 1):mt
                    Dagger.@spawn BLAS.trsm!(
                        'R', uplo, trans, 'N', pos_one, In(diag), InOut(tiles[m, k])
                    )
                end
                for n in (k + 1):nt
                    Dagger.@spawn rankk!(
                        uplo, 'N', real_neg_one, In(tiles[n, k]),
                        real_pos_one, InOut(tiles[n, n]),
                    )
                    for m in (n + 1):mt
                        Dagger.@spawn BLAS.gemm!(
                            'N', trans, neg_one, In(tiles[m, k]), In(tiles[n, k]),
                            pos_one, InOut(tiles[m, n]),
                        )
                    end
                end
            end
        end
    catch failure
        # Only a task failure whose unwrapped cause is a PosDefException is
        # tolerated; everything else propagates unchanged.
        failure isa ThunkFailedException || rethrow()
        cause = Dagger.Sch.unwrap_nested_exception(failure.ex)
        cause isa PosDefException || rethrow()
    end

    return LowerTriangular(A), info[1]
end
0 commit comments