Skip to content

Commit 0463e8a

Browse files
Merge #93
93: Prep alleviating latency issues r=charleskawczynski a=charleskawczynski This PR takes a couple steps to tackle some latency issues: - [Avoid UnionAll::DataType in struct](52c9ff0), which was shown to improve inference [here](CliMA/Thermodynamics.jl#71), and a smaller MWE [here](https://gist.github.com/charleskawczynski/61efc7b769e14eab7231e686773b9d0e) - [Allow SparseContainers to work with Arrays](c5c53ee) Co-authored-by: Charles Kawczynski <[email protected]>
2 parents 17e3b55 + c5c53ee commit 0463e8a

File tree

3 files changed

+34
-8
lines changed

3 files changed

+34
-8
lines changed

src/solvers/newtons_method.jl

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ end
327327
KrylovMethod(;
328328
jacobian_free_jvp = nothing,
329329
forcing_term = ConstantForcing(0),
330-
type = Krylov.GmresSolver,
330+
type = Val(Krylov.GmresSolver),
331331
args = (20,),
332332
kwargs = (;),
333333
solve_kwargs = (;),
@@ -347,7 +347,8 @@ where `x_prototype` is `similar` to `x` (and also to `Δx` and `f`).
347347
348348
This is primarily a wrapper for a `Krylov.KrylovSolver` from `Krylov.jl`. In
349349
`allocate_cache`, the solver is constructed with
350-
`solver = type(l, l, args..., Krylov.ktypeof(x_prototype); kwargs...)`, where
350+
`solver = type(l, l, args..., Krylov.ktypeof(x_prototype); kwargs...)` (note
351+
that `type` must be passed through in a `Val` struct), where
351352
`l = length(x_prototype)` and `Krylov.ktypeof(x_prototype)` is a subtype of
352353
`DenseVector` that can be used to store `x_prototype`. By default, the solver
353354
is a `Krylov.GmresSolver` with a Krylov subspace size of 20 (the default Krylov
@@ -387,17 +388,17 @@ each iteration of the Krylov method. If a debugger is specified, it is run
387388
before the call to `Kyrlov.solve!`.
388389
"""
389390
Base.@kwdef struct KrylovMethod{
391+
T <: Val,
390392
J <: Union{Nothing, JacobianFreeJVP},
391393
F <: ForcingTerm,
392-
T <: Type,
393394
A <: Tuple,
394395
K <: NamedTuple,
395396
S <: NamedTuple,
396397
D <: Union{Nothing, KrylovMethodDebugger},
397398
}
399+
type::T = Val(Krylov.GmresSolver)
398400
jacobian_free_jvp::J = nothing
399401
forcing_term::F = ConstantForcing(0)
400-
type::T = Krylov.GmresSolver
401402
args::A = (20,)
402403
kwargs::K = (;)
403404
solve_kwargs::S = (;)
@@ -406,9 +407,12 @@ Base.@kwdef struct KrylovMethod{
406407
debugger::D = nothing
407408
end
408409

410+
solver_type(::KrylovMethod{Val{T}}) where {T} = T
411+
409412
function allocate_cache(alg::KrylovMethod, x_prototype)
410-
(; jacobian_free_jvp, forcing_term, type, args, kwargs, debugger) = alg
411-
@assert alg.type isa Type{<:Krylov.KrylovSolver}
413+
(; jacobian_free_jvp, forcing_term, args, kwargs, debugger) = alg
414+
type = solver_type(alg)
415+
@assert type isa Type{<:Krylov.KrylovSolver}
412416
l = length(x_prototype)
413417
return (;
414418
jacobian_free_jvp_cache = isnothing(jacobian_free_jvp) ? nothing :
@@ -421,8 +425,9 @@ function allocate_cache(alg::KrylovMethod, x_prototype)
421425
end
422426

423427
function run!(alg::KrylovMethod, cache, Δx, x, f!, f, n, j = nothing)
424-
(; jacobian_free_jvp, forcing_term, type, solve_kwargs) = alg
428+
(; jacobian_free_jvp, forcing_term, solve_kwargs) = alg
425429
(; disable_preconditioner, verbose, debugger) = alg
430+
type = solver_type(alg)
426431
(; jacobian_free_jvp_cache, forcing_term_cache, solver, debugger_cache) =
427432
cache
428433
jΔx!(jΔx, Δx) = isnothing(jacobian_free_jvp) ? mul!(jΔx, j, Δx) :

src/sparse_containers.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,15 @@ struct SparseContainer{SIM, T}
2525
function SparseContainer(
2626
compressed_data::T,
2727
sparse_index_map::Tuple
28-
) where {N, ET, T <: NTuple{N, ET}}
28+
) where {T}
2929
@assert all(map(x-> eltype(compressed_data) .== typeof(x), compressed_data))
3030
return new{sparse_index_map, T}(compressed_data)
3131
end
3232
end
3333

3434
Base.parent(sc::SparseContainer) = sc.data
3535
sc_eltype(::Type{NTuple{N, T}}) where {N, T} = T
36+
sc_eltype(::Type{T}) where {ET, T <: AbstractArray{ET}} = ET
3637
sc_eltype(::SparseContainer{SIM, T}) where {SIM, T} = sc_eltype(T)
3738
@inline function Base.getindex(sc::SparseContainer, i::Int)
3839
return _getindex_sparse(sc, Val(i))::sc_eltype(sc)

test/sparse_containers.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,24 @@ using Test
1919

2020
@test_throws ErrorException("No index 2 found in sparse index map (1, 3, 5, 7)") v[2]
2121
@test_throws ErrorException("No index 8 found in sparse index map (1, 3, 5, 7)") v[8]
22+
@inferred v[7]
23+
24+
a1 = ones(3) .* 1
25+
a2 = ones(3) .* 2
26+
a3 = ones(3) .* 3
27+
a4 = ones(3) .* 4
28+
v = SparseContainer([a1,a2,a3,a4], (1,3,5,7))
29+
@test v[1] == ones(3) .* 1
30+
@test v[3] == ones(3) .* 2
31+
@test v[5] == ones(3) .* 3
32+
@test v[7] == ones(3) .* 4
33+
34+
@test parent(v)[1] == ones(3) .* 1
35+
@test parent(v)[2] == ones(3) .* 2
36+
@test parent(v)[3] == ones(3) .* 3
37+
@test parent(v)[4] == ones(3) .* 4
38+
39+
@test_throws ErrorException("No index 2 found in sparse index map (1, 3, 5, 7)") v[2]
40+
@test_throws ErrorException("No index 8 found in sparse index map (1, 3, 5, 7)") v[8]
41+
@inferred v[7]
2242
end

0 commit comments

Comments
 (0)