Skip to content

Commit 3a9cc4f

Browse files
authored
Switch to Runic formatter (#58)
* Runic formatter * Formatter action * Remove .JuliaFormatter
1 parent 3871df1 commit 3a9cc4f

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

57 files changed

+1837
-1384
lines changed

.JuliaFormatter.toml

Lines changed: 0 additions & 1 deletion
This file was deleted.

.github/workflows/FormatCheck.yml

Lines changed: 6 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,29 +4,14 @@ on:
44
push:
55
branches:
66
- 'main'
7+
- 'master'
78
tags: '*'
89
pull_request:
910
branches:
1011
- 'main'
11-
jobs:
12-
build:
13-
runs-on: ${{ matrix.os }}
14-
strategy:
15-
matrix:
16-
version:
17-
- '1' # automatically expands to the latest stable 1.x release of Julia
18-
os:
19-
- ubuntu-latest
20-
arch:
21-
- x64
22-
steps:
23-
- uses: julia-actions/setup-julia@latest
24-
with:
25-
version: ${{ matrix.version }}
26-
arch: ${{ matrix.arch }}
12+
- 'master'
2713

28-
- uses: actions/checkout@v5
29-
- name: Install JuliaFormatter and format
30-
run: |
31-
julia -e 'using Pkg; Pkg.add(PackageSpec(name="JuliaFormatter"))'
32-
julia -e 'using JuliaFormatter; format(".", verbose=true)'
14+
jobs:
15+
formatcheck:
16+
name: "Format Check"
17+
uses: "QuantumKitHub/QuantumKitHubActions/.github/workflows/FormatCheck.yml@main"

docs/make.jl

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
if Base.active_project() != joinpath(@__DIR__, "Project.toml")
33
using Pkg
44
Pkg.activate(@__DIR__)
5-
Pkg.develop(PackageSpec(; path=joinpath(@__DIR__, "..")))
5+
Pkg.develop(PackageSpec(; path = joinpath(@__DIR__, "..")))
66
Pkg.resolve()
77
Pkg.instantiate()
88
end
@@ -24,28 +24,42 @@ plugins = []
2424
# "BlockTensorKit" => "https://lkdvos.github.io/BlockTensorKit.jl/dev/")
2525
# push!(plugins, links)
2626

27-
DocMeta.setdocmeta!(MatrixAlgebraKit, :DocTestSetup, :(using MatrixAlgebraKit);
28-
recursive=true)
27+
DocMeta.setdocmeta!(
28+
MatrixAlgebraKit, :DocTestSetup, :(using MatrixAlgebraKit);
29+
recursive = true
30+
)
2931

30-
mathengine = MathJax3(Dict(:loader => Dict("load" => ["[tex]/physics"]),
31-
:tex => Dict("inlineMath" => [["\$", "\$"], ["\\(", "\\)"]],
32-
"tags" => "ams",
33-
"packages" => ["base", "ams", "autoload", "physics"])))
32+
mathengine = MathJax3(
33+
Dict(
34+
:loader => Dict("load" => ["[tex]/physics"]),
35+
:tex => Dict(
36+
"inlineMath" => [["\$", "\$"], ["\\(", "\\)"]],
37+
"tags" => "ams",
38+
"packages" => ["base", "ams", "autoload", "physics"]
39+
)
40+
)
41+
)
3442
makedocs(;
35-
sitename="MatrixAlgebraKit.jl",
36-
format=Documenter.HTML(;
37-
prettyurls=get(ENV, "CI", nothing) == "true",
38-
mathengine,
39-
size_threshold=512000),
40-
pages=["Home" => "index.md",
41-
"User Interface" => ["user_interface/compositions.md",
42-
"user_interface/decompositions.md",
43-
"user_interface/truncations.md",
44-
"user_interface/matrix_functions.md"],
45-
"Developer Interface" => "dev_interface.md",
46-
"Library" => "library.md"],
47-
checkdocs=:exports,
48-
doctest=true,
49-
plugins)
43+
sitename = "MatrixAlgebraKit.jl",
44+
format = Documenter.HTML(;
45+
prettyurls = get(ENV, "CI", nothing) == "true",
46+
mathengine,
47+
size_threshold = 512000
48+
),
49+
pages = [
50+
"Home" => "index.md",
51+
"User Interface" => [
52+
"user_interface/compositions.md",
53+
"user_interface/decompositions.md",
54+
"user_interface/truncations.md",
55+
"user_interface/matrix_functions.md",
56+
],
57+
"Developer Interface" => "dev_interface.md",
58+
"Library" => "library.md",
59+
],
60+
checkdocs = :exports,
61+
doctest = true,
62+
plugins
63+
)
5064

51-
deploydocs(; repo="github.com/QuantumKitHub/MatrixAlgebraKit.jl.git", push_preview=true)
65+
deploydocs(; repo = "github.com/QuantumKitHub/MatrixAlgebraKit.jl.git", push_preview = true)

ext/MatrixAlgebraKitAMDGPUExt/MatrixAlgebraKitAMDGPUExt.jl

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,32 +14,39 @@ using LinearAlgebra: BlasFloat
1414

1515
include("yarocsolver.jl")
1616

17-
function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {T<:StridedROCMatrix}
17+
function MatrixAlgebraKit.default_qr_algorithm(::Type{T}; kwargs...) where {T <: StridedROCMatrix}
1818
return ROCSOLVER_HouseholderQR(; kwargs...)
1919
end
20-
function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {T<:StridedROCMatrix}
20+
function MatrixAlgebraKit.default_lq_algorithm(::Type{T}; kwargs...) where {T <: StridedROCMatrix}
2121
qr_alg = ROCSOLVER_HouseholderQR(; kwargs...)
2222
return LQViaTransposedQR(qr_alg)
2323
end
24-
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T<:StridedROCMatrix}
24+
function MatrixAlgebraKit.default_svd_algorithm(::Type{T}; kwargs...) where {T <: StridedROCMatrix}
2525
return ROCSOLVER_QRIteration(; kwargs...)
2626
end
27-
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T<:StridedROCMatrix}
27+
function MatrixAlgebraKit.default_eigh_algorithm(::Type{T}; kwargs...) where {T <: StridedROCMatrix}
2828
return ROCSOLVER_DivideAndConquer(; kwargs...)
2929
end
3030

3131
_gpu_geqrf!(A::StridedROCMatrix) = YArocSOLVER.geqrf!(A)
3232
_gpu_ungqr!(A::StridedROCMatrix, τ::StridedROCVector) = YArocSOLVER.ungqr!(A, τ)
33-
_gpu_unmqr!(side::AbstractChar, trans::AbstractChar, A::StridedROCMatrix, τ::StridedROCVector, C::StridedROCVecOrMat) = YArocSOLVER.unmqr!(side, trans, A, τ, C)
34-
_gpu_gesvd!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix) = YArocSOLVER.gesvd!(A, S, U, Vᴴ)
33+
_gpu_unmqr!(side::AbstractChar, trans::AbstractChar, A::StridedROCMatrix, τ::StridedROCVector, C::StridedROCVecOrMat) =
34+
YArocSOLVER.unmqr!(side, trans, A, τ, C)
35+
_gpu_gesvd!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix) =
36+
YArocSOLVER.gesvd!(A, S, U, Vᴴ)
3537
# not yet supported
36-
#_gpu_Xgesvdp!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...) = YArocSOLVER.Xgesvdp!(A, S, U, Vᴴ; kwargs...)
37-
_gpu_gesvdj!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...) = YArocSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...)
38-
39-
_gpu_heevj!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = YArocSOLVER.heevj!(A, Dd, V; kwargs...)
40-
_gpu_heevd!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = YArocSOLVER.heevd!(A, Dd, V; kwargs...)
41-
_gpu_heev!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = YArocSOLVER.heev!(A, Dd, V; kwargs...)
42-
_gpu_heevx!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) = YArocSOLVER.heevx!(A, Dd, V; kwargs...)
38+
# _gpu_Xgesvdp!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...) =
39+
# YArocSOLVER.Xgesvdp!(A, S, U, Vᴴ; kwargs...)
40+
_gpu_gesvdj!(A::StridedROCMatrix, S::StridedROCVector, U::StridedROCMatrix, Vᴴ::StridedROCMatrix; kwargs...) =
41+
YArocSOLVER.gesvdj!(A, S, U, Vᴴ; kwargs...)
42+
_gpu_heevj!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) =
43+
YArocSOLVER.heevj!(A, Dd, V; kwargs...)
44+
_gpu_heevd!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) =
45+
YArocSOLVER.heevd!(A, Dd, V; kwargs...)
46+
_gpu_heev!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) =
47+
YArocSOLVER.heev!(A, Dd, V; kwargs...)
48+
_gpu_heevx!(A::StridedROCMatrix, Dd::StridedROCVector, V::StridedROCMatrix; kwargs...) =
49+
YArocSOLVER.heevx!(A, Dd, V; kwargs...)
4350

4451
function MatrixAlgebraKit.findtruncated_svd(values::StridedROCVector, strategy::TruncationByValue)
4552
return MatrixAlgebraKit.findtruncated(values, strategy)

ext/MatrixAlgebraKitAMDGPUExt/yarocsolver.jl

Lines changed: 82 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -16,17 +16,19 @@ const ungqr! = orgqr!
1616

1717
# Wrapper for SVD via QR Iteration
1818
for (fname, elty, relty) in
19-
((:rocsolver_sgesvd, :Float32, :Float32),
20-
(:rocsolver_dgesvd, :Float64, :Float64),
21-
(:rocsolver_cgesvd, :ComplexF32, :Float32),
22-
(:rocsolver_zgesvd, :ComplexF64, :Float64))
19+
(
20+
(:rocsolver_sgesvd, :Float32, :Float32),
21+
(:rocsolver_dgesvd, :Float64, :Float64),
22+
(:rocsolver_cgesvd, :ComplexF32, :Float32),
23+
(:rocsolver_zgesvd, :ComplexF64, :Float64),
24+
)
2325
@eval begin
24-
#! format: off
25-
function gesvd!(A::StridedROCMatrix{$elty},
26-
S::StridedROCVector{$relty}=similar(A, $relty, min(size(A)...)),
27-
U::StridedROCMatrix{$elty}=similar(A, $elty, size(A, 1), min(size(A)...)),
28-
Vᴴ::StridedROCMatrix{$elty}=similar(A, $elty, min(size(A)...), size(A, 2)))
29-
#! format: on
26+
function gesvd!(
27+
A::StridedROCMatrix{$elty},
28+
S::StridedROCVector{$relty} = similar(A, $relty, min(size(A)...)),
29+
U::StridedROCMatrix{$elty} = similar(A, $elty, size(A, 1), min(size(A)...)),
30+
Vᴴ::StridedROCMatrix{$elty} = similar(A, $elty, min(size(A)...), size(A, 2))
31+
)
3032
chkstride1(A, U, Vᴴ, S)
3133
m, n = size(A)
3234
(m < n) && throw(ArgumentError("rocSOLVER's gesvd requires m ≥ n"))
@@ -72,13 +74,15 @@ for (fname, elty, relty) in
7274
ldu = max(1, stride(U, 2))
7375
ldv = max(1, stride(Vᴴ, 2))
7476

75-
rwork = ROCArray{$relty}(undef, minmn - 1)
76-
dh = rocBLAS.handle()
77+
rwork = ROCArray{$relty}(undef, minmn - 1)
78+
dh = rocBLAS.handle()
7779
dev_info = ROCVector{Cint}(undef, 1)
78-
rocSOLVER.$fname(dh, jobu, jobvt, m, n,
79-
A, lda, S, U, ldu, Vᴴ, ldv,
80-
rwork, convert(rocSOLVER.rocblas_workmode, 'I'),
81-
dev_info)
80+
rocSOLVER.$fname(
81+
dh, jobu, jobvt, m, n,
82+
A, lda, S, U, ldu, Vᴴ, ldv,
83+
rwork, convert(rocSOLVER.rocblas_workmode, 'I'),
84+
dev_info
85+
)
8286
AMDGPU.unsafe_free!(rwork)
8387

8488
info = @allowscalar dev_info[1]
@@ -91,20 +95,21 @@ end
9195

9296
# Wrapper for SVD via Jacobi
9397
for (fname, elty, relty) in
94-
((:rocsolver_sgesvdj, :Float32, :Float32),
95-
(:rocsolver_dgesvdj, :Float64, :Float64),
96-
(:rocsolver_cgesvdj, :ComplexF32, :Float32),
97-
(:rocsolver_zgesvdj, :ComplexF64, :Float64))
98+
(
99+
(:rocsolver_sgesvdj, :Float32, :Float32),
100+
(:rocsolver_dgesvdj, :Float64, :Float64),
101+
(:rocsolver_cgesvdj, :ComplexF32, :Float32),
102+
(:rocsolver_zgesvdj, :ComplexF64, :Float64),
103+
)
98104
@eval begin
99-
#! format: off
100-
function gesvdj!(A::StridedROCMatrix{$elty},
101-
S::StridedROCVector{$relty}=similar(A, $relty, min(size(A)...)),
102-
U::StridedROCMatrix{$elty}=similar(A, $elty, size(A, 1), min(size(A)...)),
103-
Vᴴ::StridedROCMatrix{$elty}=similar(A, $elty, min(size(A)...), size(A, 2));
104-
tol::$relty=eps($relty),
105-
max_sweeps::Int=100,
106-
)
107-
#! format: on
105+
function gesvdj!(
106+
A::StridedROCMatrix{$elty},
107+
S::StridedROCVector{$relty} = similar(A, $relty, min(size(A)...)),
108+
U::StridedROCMatrix{$elty} = similar(A, $elty, size(A, 1), min(size(A)...)),
109+
Vᴴ::StridedROCMatrix{$elty} = similar(A, $elty, min(size(A)...), size(A, 2));
110+
tol::$relty = eps($relty),
111+
max_sweeps::Int = 100,
112+
)
108113
chkstride1(A, U, Vᴴ, S)
109114
m, n = size(A)
110115
minmn = min(m, n)
@@ -149,21 +154,22 @@ for (fname, elty, relty) in
149154
lda = max(1, stride(A, 2))
150155
ldu = max(1, stride(U, 2))
151156
ldv = max(1, stride(Vᴴ, 2))
152-
dev_info = ROCVector{Cint}(undef, 1)
157+
dev_info = ROCVector{Cint}(undef, 1)
153158
dev_residual = ROCVector{$relty}(undef, 1)
154159
dev_n_sweeps = ROCVector{Cint}(undef, 1)
155160

156161
dh = rocBLAS.handle()
157-
rocSOLVER.$fname(dh, jobu, jobvt, m, n, A, lda, tol,
158-
dev_residual, max_sweeps, dev_n_sweeps,
159-
S, U, ldu, Vᴴ, ldv, dev_info,
160-
)
162+
rocSOLVER.$fname(
163+
dh, jobu, jobvt, m, n, A, lda, tol,
164+
dev_residual, max_sweeps, dev_n_sweeps,
165+
S, U, ldu, Vᴴ, ldv, dev_info,
166+
)
161167

162168
info = @allowscalar dev_info[1]
163169
rocSOLVER.chkargsok(BlasInt(info))
164170

165-
AMDGPU.unsafe_free!(dev_residual)
166-
AMDGPU.unsafe_free!(dev_n_sweeps)
171+
AMDGPU.unsafe_free!(dev_residual)
172+
AMDGPU.unsafe_free!(dev_n_sweeps)
167173
return (S, U, Vᴴ)
168174
end
169175
end
@@ -476,15 +482,19 @@ end
476482
# end
477483

478484
for (heevd, heev, heevx, heevj, elty, relty) in
479-
((:(rocSOLVER.rocsolver_ssyevd), :(rocSOLVER.rocsolver_ssyev), :(rocSOLVER.rocsolver_ssyevx), :(rocSOLVER.rocsolver_ssyevj), :Float32, :Float32),
480-
(:(rocSOLVER.rocsolver_dsyevd), :(rocSOLVER.rocsolver_dsyev), :(rocSOLVER.rocsolver_dsyevx), :(rocSOLVER.rocsolver_dsyevj), :Float64, :Float64),
481-
(:(rocSOLVER.rocsolver_cheevd), :(rocSOLVER.rocsolver_cheev), :(rocSOLVER.rocsolver_cheevx), :(rocSOLVER.rocsolver_cheevj), :ComplexF32, :Float32),
482-
(:(rocSOLVER.rocsolver_zheevd), :(rocSOLVER.rocsolver_zheev), :(rocSOLVER.rocsolver_zheevx), :(rocSOLVER.rocsolver_zheevj), :ComplexF64, :Float64))
485+
(
486+
(:(rocSOLVER.rocsolver_ssyevd), :(rocSOLVER.rocsolver_ssyev), :(rocSOLVER.rocsolver_ssyevx), :(rocSOLVER.rocsolver_ssyevj), :Float32, :Float32),
487+
(:(rocSOLVER.rocsolver_dsyevd), :(rocSOLVER.rocsolver_dsyev), :(rocSOLVER.rocsolver_dsyevx), :(rocSOLVER.rocsolver_dsyevj), :Float64, :Float64),
488+
(:(rocSOLVER.rocsolver_cheevd), :(rocSOLVER.rocsolver_cheev), :(rocSOLVER.rocsolver_cheevx), :(rocSOLVER.rocsolver_cheevj), :ComplexF32, :Float32),
489+
(:(rocSOLVER.rocsolver_zheevd), :(rocSOLVER.rocsolver_zheev), :(rocSOLVER.rocsolver_zheevx), :(rocSOLVER.rocsolver_zheevj), :ComplexF64, :Float64),
490+
)
483491
@eval begin
484-
function heevd!(A::StridedROCMatrix{$elty},
485-
W::StridedROCVector{$relty},
486-
V::StridedROCMatrix{$elty};
487-
uplo::Char='U')
492+
function heevd!(
493+
A::StridedROCMatrix{$elty},
494+
W::StridedROCVector{$relty},
495+
V::StridedROCMatrix{$elty};
496+
uplo::Char = 'U'
497+
)
488498
chkuplo(uplo)
489499
n = checksquare(A)
490500
lda = max(1, stride(A, 2))
@@ -509,10 +519,12 @@ for (heevd, heev, heevx, heevj, elty, relty) in
509519
end
510520
return W, V
511521
end
512-
function heev!(A::StridedROCMatrix{$elty},
513-
W::StridedROCVector{$relty},
514-
V::StridedROCMatrix{$elty};
515-
uplo::Char='U')
522+
function heev!(
523+
A::StridedROCMatrix{$elty},
524+
W::StridedROCVector{$relty},
525+
V::StridedROCMatrix{$elty};
526+
uplo::Char = 'U'
527+
)
516528
chkuplo(uplo)
517529
n = checksquare(A)
518530
lda = max(1, stride(A, 2))
@@ -537,11 +549,13 @@ for (heevd, heev, heevx, heevj, elty, relty) in
537549
end
538550
return W, V
539551
end
540-
function heevx!(A::StridedROCMatrix{$elty},
541-
W::StridedROCVector{$relty},
542-
V::StridedROCMatrix{$elty};
543-
uplo::Char='U',
544-
kwargs...)
552+
function heevx!(
553+
A::StridedROCMatrix{$elty},
554+
W::StridedROCVector{$relty},
555+
V::StridedROCMatrix{$elty};
556+
uplo::Char = 'U',
557+
kwargs...
558+
)
545559
chkuplo(uplo)
546560
n = checksquare(A)
547561
lda = max(1, stride(A, 2))
@@ -567,27 +581,29 @@ for (heevd, heev, heevx, heevj, elty, relty) in
567581
size(V) == (n, n) || throw(DimensionMismatch("size mismatch between A and V"))
568582
jobz = rocSOLVER.rocblas_evect_original
569583
end
570-
dh = rocBLAS.handle()
571-
abstol = -one($relty)
572-
nev = ROCVector{Cint}(undef, 1)
573-
ldv = max(1, stride(V, 2))
574-
ifail = ROCVector{Cint}(undef, n)
584+
dh = rocBLAS.handle()
585+
abstol = -one($relty)
586+
nev = ROCVector{Cint}(undef, 1)
587+
ldv = max(1, stride(V, 2))
588+
ifail = ROCVector{Cint}(undef, n)
575589
dev_info = ROCVector{Cint}(undef, 1)
576590
roc_uplo = convert(rocSOLVER.rocblas_fill, uplo)
577591
$heevx(dh, jobz, range, roc_uplo, n, A, lda, vl, vu, il, iu, abstol, nev, W, V, ldv, ifail, dev_info)
578592

579593
info = @allowscalar dev_info[1]
580594
chkargsok(BlasInt(info))
581-
m = @allowscalar nev[1]
595+
m = @allowscalar nev[1]
582596
return W, V, m
583597
end
584-
function heevj!(A::StridedROCMatrix{$elty},
585-
W::StridedROCVector{$relty},
586-
V::StridedROCMatrix{$elty};
587-
uplo::Char='U',
588-
tol::$relty=eps($relty),
589-
max_sweeps::Int=100,
590-
sort::Char='N')
598+
function heevj!(
599+
A::StridedROCMatrix{$elty},
600+
W::StridedROCVector{$relty},
601+
V::StridedROCMatrix{$elty};
602+
uplo::Char = 'U',
603+
tol::$relty = eps($relty),
604+
max_sweeps::Int = 100,
605+
sort::Char = 'N'
606+
)
591607
chkuplo(uplo)
592608
n = checksquare(A)
593609
lda = max(1, stride(A, 2))

0 commit comments

Comments
 (0)