Skip to content

Commit c8104cf

Browse files
committed
reapplying formatting
1 parent 763ad4f commit c8104cf

38 files changed

+718
-649
lines changed

.JuliaFormatter.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
style = "sciml"
22
format_markdown = true
3+
format_docstrings = true
34
annotate_untyped_fields_with_any = false

benchmarks/applelu.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ algs = [
2323
GenericLUFactorization(),
2424
RFLUFactorization(),
2525
AppleAccelerateLUFactorization(),
26-
MetalLUFactorization(),
26+
MetalLUFactorization()
2727
]
2828
res = [Float32[] for i in 1:length(algs)]
2929

benchmarks/lu.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ algs = [
2424
RFLUFactorization(),
2525
MKLLUFactorization(),
2626
FastLUFactorization(),
27-
SimpleLUFactorization(),
27+
SimpleLUFactorization()
2828
]
2929
res = [Float64[] for i in 1:length(algs)]
3030

benchmarks/sparselu.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ algs = [
3636
UMFPACKFactorization(),
3737
KLUFactorization(),
3838
MKLPardisoFactorize(),
39-
SparspakFactorization(),
39+
SparspakFactorization()
4040
]
4141
cols = [:red, :blue, :green, :magenta, :turqoise] # one color per alg
4242
lst = [:dash, :solid, :dashdot] # one line style per dim
@@ -65,7 +65,8 @@ function run_and_plot(; dims = [1, 2, 3], kmax = 12)
6565
u0 = rand(rng, n)
6666

6767
for j in 1:length(algs)
68-
bt = @belapsed solve(prob, $(algs[j])).u setup=(prob = LinearProblem(copy($A),
68+
bt = @belapsed solve(prob, $(algs[j])).u setup=(prob = LinearProblem(
69+
copy($A),
6970
copy($b);
7071
u0 = copy($u0),
7172
alias_A = true,

docs/pages.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22

33
pages = ["index.md",
44
"Tutorials" => Any["tutorials/linear.md"
5-
"tutorials/caching_interface.md"],
5+
"tutorials/caching_interface.md"],
66
"Basics" => Any["basics/LinearProblem.md",
77
"basics/common_solver_opts.md",
88
"basics/OperatorAssumptions.md",
99
"basics/Preconditioners.md",
1010
"basics/FAQ.md"],
1111
"Solvers" => Any["solvers/solvers.md"],
1212
"Advanced" => Any["advanced/developing.md"
13-
"advanced/custom.md"],
14-
"Release Notes" => "release_notes.md",
13+
"advanced/custom.md"],
14+
"Release Notes" => "release_notes.md"
1515
]

docs/src/advanced/developing.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ basic machinery. A simplified version is:
1818
struct MyLUFactorization{P} <: SciMLBase.AbstractLinearAlgorithm end
1919

2020
function init_cacheval(alg::MyLUFactorization, A, b, u, Pl, Pr, maxiters, abstol, reltol,
21-
verbose)
21+
verbose)
2222
lu!(convert(AbstractMatrix, A))
2323
end
2424

ext/LinearSolveBandedMatricesExt.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@ module LinearSolveBandedMatricesExt
22

33
using BandedMatrices, LinearAlgebra, LinearSolve
44
import LinearSolve: defaultalg,
5-
do_factorization, init_cacheval, DefaultLinearSolver, DefaultAlgorithmChoice
5+
do_factorization, init_cacheval, DefaultLinearSolver,
6+
DefaultAlgorithmChoice
67

78
# Defaults for BandedMatrices
89
function defaultalg(A::BandedMatrix, b, oa::OperatorAssumptions{Bool})
@@ -35,14 +36,14 @@ for alg in (:SVDFactorization, :MKLLUFactorization, :DiagonalFactorization,
3536
:AppleAccelerateLUFactorization, :CholeskyFactorization)
3637
@eval begin
3738
function init_cacheval(::$(alg), ::BandedMatrix, b, u, Pl, Pr, maxiters::Int,
38-
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
39+
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
3940
return nothing
4041
end
4142
end
4243
end
4344

4445
function init_cacheval(::LUFactorization, A::BandedMatrix, b, u, Pl, Pr, maxiters::Int,
45-
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
46+
abstol, reltol, verbose::Bool, assumptions::OperatorAssumptions)
4647
return lu(similar(A, 0, 0))
4748
end
4849

@@ -54,8 +55,8 @@ for alg in (:SVDFactorization, :MKLLUFactorization, :DiagonalFactorization,
5455
:AppleAccelerateLUFactorization, :QRFactorization, :LUFactorization)
5556
@eval begin
5657
function init_cacheval(::$(alg), ::Symmetric{<:Number, <:BandedMatrix}, b, u, Pl,
57-
Pr, maxiters::Int, abstol, reltol, verbose::Bool,
58-
assumptions::OperatorAssumptions)
58+
Pr, maxiters::Int, abstol, reltol, verbose::Bool,
59+
assumptions::OperatorAssumptions)
5960
return nothing
6061
end
6162
end

ext/LinearSolveBlockDiagonalsExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module LinearSolveBlockDiagonalsExt
33
using LinearSolve, BlockDiagonals
44

55
function LinearSolve.init_cacheval(alg::SimpleGMRES{false}, A::BlockDiagonal, b, args...;
6-
kwargs...)
6+
kwargs...)
77
@assert ndims(A)==2 "ndims(A) == $(ndims(A)). `A` must have ndims == 2."
88
# We need to perform this check even when `zeroinit == true`, since the type of the
99
# cache is dependent on whether we are able to use the specialized dispatch.

ext/LinearSolveCUDAExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ using LinearSolve.LinearAlgebra, LinearSolve.SciMLBase, LinearSolve.ArrayInterfa
66
using SciMLBase: AbstractSciMLOperator
77

88
function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadFactorization;
9-
kwargs...)
9+
kwargs...)
1010
if cache.isfresh
1111
fact = qr(CUDA.CuArray(cache.A))
1212
cache.cacheval = fact
@@ -18,8 +18,8 @@ function SciMLBase.solve!(cache::LinearSolve.LinearCache, alg::CudaOffloadFactor
1818
end
1919

2020
function LinearSolve.init_cacheval(alg::CudaOffloadFactorization, A, b, u, Pl, Pr,
21-
maxiters::Int, abstol, reltol, verbose::Bool,
22-
assumptions::OperatorAssumptions)
21+
maxiters::Int, abstol, reltol, verbose::Bool,
22+
assumptions::OperatorAssumptions)
2323
qr(CUDA.CuArray(A))
2424
end
2525

ext/LinearSolveEnzymeExt.jl

Lines changed: 24 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@ using LinearSolve
44
using LinearSolve.LinearAlgebra
55
isdefined(Base, :get_extension) ? (import Enzyme) : (import ..Enzyme)
66

7-
87
using Enzyme
98

109
using EnzymeCore
1110

12-
function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
11+
function EnzymeCore.EnzymeRules.forward(
12+
func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP},
13+
alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
1314
@assert !(prob isa Const)
1415
res = func.val(prob.val, alg.val; kwargs...)
1516
if RT <: Const
@@ -26,11 +27,13 @@ function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.init)}, :
2627
error("Unsupported return type $RT")
2728
end
2829

29-
function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache}
30+
function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.solve!)},
31+
::Type{RT}, linsolve::EnzymeCore.Annotation{LP};
32+
kwargs...) where {RT, LP <: LinearSolve.LinearCache}
3033
@assert !(linsolve isa Const)
3134

3235
res = func.val(linsolve.val; kwargs...)
33-
36+
3437
if RT <: Const
3538
return res
3639
end
@@ -56,7 +59,10 @@ function EnzymeCore.EnzymeRules.forward(func::Const{typeof(LinearSolve.solve!)},
5659
return Duplicated(res, dres)
5760
end
5861

59-
function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
62+
function EnzymeCore.EnzymeRules.augmented_primal(
63+
config, func::Const{typeof(LinearSolve.init)},
64+
::Type{RT}, prob::EnzymeCore.Annotation{LP}, alg::Const;
65+
kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
6066
res = func.val(prob.val, alg.val; kwargs...)
6167
dres = if EnzymeRules.width(config) == 1
6268
func.val(prob.dval, alg.val; kwargs...)
@@ -77,7 +83,6 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line
7783
(dval.b for dval in dres)
7884
end
7985

80-
8186
prob_d_A = if EnzymeRules.width(config) == 1
8287
prob.dval.A
8388
else
@@ -92,7 +97,10 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line
9297
return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, (d_A, d_b, prob_d_A, prob_d_b))
9398
end
9499

95-
function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.init)}, ::Type{RT}, cache, prob::EnzymeCore.Annotation{LP}, alg::Const; kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
100+
function EnzymeCore.EnzymeRules.reverse(
101+
config, func::Const{typeof(LinearSolve.init)}, ::Type{RT},
102+
cache, prob::EnzymeCore.Annotation{LP}, alg::Const;
103+
kwargs...) where {RT, LP <: LinearSolve.LinearProblem}
96104
d_A, d_b, prob_d_A, prob_d_b = cache
97105

98106
if EnzymeRules.width(config) == 1
@@ -105,7 +113,7 @@ function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.i
105113
d_b .= 0
106114
end
107115
else
108-
for (_prob_d_A,_d_A,_prob_d_b, _d_b) in zip(prob_d_A, d_A, prob_d_b, d_b)
116+
for (_prob_d_A, _d_A, _prob_d_b, _d_b) in zip(prob_d_A, d_A, prob_d_b, d_b)
109117
if _d_A !== _prob_d_A
110118
_prob_d_A .+= _d_A
111119
_d_A .= 0
@@ -123,7 +131,10 @@ end
123131
# y=inv(A) B
124132
# dA −= z y^T
125133
# dB += z, where z = inv(A^T) dy
126-
function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache}
134+
function EnzymeCore.EnzymeRules.augmented_primal(
135+
config, func::Const{typeof(LinearSolve.solve!)},
136+
::Type{RT}, linsolve::EnzymeCore.Annotation{LP};
137+
kwargs...) where {RT, LP <: LinearSolve.LinearCache}
127138
res = func.val(linsolve.val; kwargs...)
128139

129140
dres = if EnzymeRules.width(config) == 1
@@ -176,7 +187,9 @@ function EnzymeCore.EnzymeRules.augmented_primal(config, func::Const{typeof(Line
176187
return EnzymeCore.EnzymeRules.AugmentedReturn(res, dres, cache)
177188
end
178189

179-
function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)}, ::Type{RT}, cache, linsolve::EnzymeCore.Annotation{LP}; kwargs...) where {RT, LP <: LinearSolve.LinearCache}
190+
function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.solve!)},
191+
::Type{RT}, cache, linsolve::EnzymeCore.Annotation{LP};
192+
kwargs...) where {RT, LP <: LinearSolve.LinearCache}
180193
y, dys, _linsolve, dAs, dbs = cache
181194

182195
@assert !(linsolve isa Const)
@@ -202,7 +215,7 @@ function EnzymeCore.EnzymeRules.reverse(config, func::Const{typeof(LinearSolve.s
202215
LinearSolve.defaultalg_adjoint_eval(_linsolve, dy)
203216
else
204217
error("Algorithm $(_linsolve.alg) is currently not supported by Enzyme rules on LinearSolve.jl. Please open an issue on LinearSolve.jl detailing which algorithm is missing the adjoint handling")
205-
end
218+
end
206219

207220
dA .-= z * transpose(y)
208221
db .+= z

0 commit comments

Comments
 (0)