diff --git a/Project.toml b/Project.toml index 9f03e65..2a70ea7 100644 --- a/Project.toml +++ b/Project.toml @@ -8,9 +8,11 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" Krylov = "ba0b0d4f-ebba-5204-a429-3ac8c609bfb7" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [compat] Enzyme = "0.13" Krylov = "0.10.1" LinearAlgebra = "1.10" +StaticArrays = "1.9.13" julia = "1.10" diff --git a/examples/Project.toml b/examples/Project.toml index 35ffa65..2b10a14 100644 --- a/examples/Project.toml +++ b/examples/Project.toml @@ -10,6 +10,7 @@ NewtonKrylov = "0be81120-40bf-4f8b-adf0-26103efb66f1" Observables = "510215fc-4207-5dde-b226-833fc4488ee2" OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" SummationByPartsOperators = "9f78cca6-572e-554e-b819-917d2f1cf240" [sources] diff --git a/examples/heat_1D.jl b/examples/heat_1D.jl index 723b9eb..bcbcdca 100644 --- a/examples/heat_1D.jl +++ b/examples/heat_1D.jl @@ -5,21 +5,17 @@ using NewtonKrylov using CairoMakie include(joinpath(dirname(pathof(NewtonKrylov)), "..", "examples", "implicit.jl")) +include(joinpath(dirname(pathof(NewtonKrylov)), "..", "examples", "stencils.jl")) # ## Heat 1D # $ \frac{\partial u(x, t)}{\partial t} = a * \frac{\partial^2 u(x, t)}{\partial x^2 $ -function heat_1D!(du, u, (a, Δx, bc!), t) - N = length(u) +function heat_1D!(du, U, (a, Δx, stencil), t) + N = length(U) - ## Enforce the boundary condition - bc!(u) - du[1] = 0 - du[end] = 0 - - ## Only compute within - for i in 2:(N - 1) - du[i] = a * (u[i + 1] - 2u[i] + u[i - 1]) / Δx^2 + for i in 1:N + u = stencil(U, i) + du[i] = a * D²ₓ(u, Δx) end return end @@ -31,19 +27,12 @@ end # L = 1 # x ∈ (0,L) -function bc!(u) - u[1] = 0 - return u[end] = 0 -end - -function periodic_bc!(u) - u[1] = u[end - 1] - return u[end] = u[2] -end - # inital condition -f(x) = 4x * (1 - x) +f(x) = sin(π * x) + +dirchlet = ThreePointStencil(Constant(0.0, 0.0)) +periodic = ThreePointStencil(Periodic()) a = 0.5 @@ -54,7 +43,7 @@ using LinearAlgebra # ## Investigate the Jacobian's # ### Euler -J = jacobian(G_Euler!, heat_1D!, zeros(N), (a, 1 / (N + 1), bc!), 0.1, 0.0) +J = jacobian(G_Euler!, heat_1D!, zeros(N), (a, 1 / (N + 1), dirchlet), 0.1, 0.0) # Rank: diff --git a/examples/implicit.jl b/examples/implicit.jl index ff521b5..6918e2f 100644 --- a/examples/implicit.jl +++ b/examples/implicit.jl @@ -5,10 +5,10 @@ using NewtonKrylov # ## Implicit Euler -function G_Euler!(res, uₙ, Δt, f!, du, u, p, t) - f!(du, u, p, t) +function G_Euler!(res, uₙ, Δt, f!, du, U, p, t) + f!(du, U, p, t) - res .= uₙ .+ Δt .* du .- u + res .= uₙ .+ Δt .* du .- U return nothing end @@ -43,7 +43,10 @@ function jacobian(G!, f!, uₙ, p, Δt, t) du = zero(uₙ) res = zero(uₙ) - F!(res, u, (uₙ, Δt, du, p, t)) = G!(res, uₙ, Δt, f!, du, u, p, t) + function F!(res, u, P) + (uₙ, Δt, du, p, t) = P + return G!(res, uₙ, Δt, f!, du, u, p, t) + end J = NewtonKrylov.JacobianOperator(F!, res, u, (uₙ, Δt, du, p, t)) return collect(J) diff --git a/examples/stencils.jl b/examples/stencils.jl new file mode 100644 index 0000000..0afcc50 --- /dev/null +++ b/examples/stencils.jl @@ -0,0 +1,67 @@ +abstract type Boundary end + +struct Periodic <: Boundary end + +Base.@propagate_inbounds function get(::Periodic, u, i) + N = length(u) + if i == 1 + return u[N] + elseif i == N + return u[i] + else + return u[i] + end +end + +struct Constant{T} <: Boundary + left::T + right::T +end + +Base.@propagate_inbounds function get(c::Constant, u, i) + N = length(u) + if i == 1 + return c.left + elseif i == N + return c.right + else + return u[i] + end +end + +using StaticArrays + +struct ThreePointStencil{B <: Boundary} + b::B +end + +function (stencil::ThreePointStencil)(u::AbstractVector, i) + @boundscheck checkbounds(u, i) + @inbounds begin + l = get(stencil, u, i - 1) + c = u[i] + r = get(stencil, u, i + r) + end + return SVector((l, c, r)) +end + +function D²ₓ(u::StaticVector, Δx) + return (u[1] - 2u[2] + u[3]) / Δx^2 +end + + +# struct Stencil{N,B} +# boundaries::B +# end + +# function (stencil::Stencil{N})(u::AbstractArray, idxs...) where N +# shape = size(u) +# @assert length(shape) == length(stencil.boundaries) == length(idxs) == N + +# region = -1:1:1 +# ntuple(Val(N)) do dim +# i = idxs[dim] + + +# end +# end