Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/manifold.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# TODO is this necessary and should it be added to REQUIRE?
using ManifoldProjections

# wrapper for non-autonomous functions
mutable struct NonAutonomousFunction{F,autonomous}
f::F
Expand Down Expand Up @@ -52,4 +55,12 @@ function ManifoldProjection(g; nlsolve=NLSOLVEJL_SETUP(), save=true,
save_positions=save_positions)
end

function ManifoldProjection(M::Mfd; save=true) where {Mfd <: Manifold}
affect!(integrator) = retract!(M, integrator.u)
condition = (u, t, integrator) -> true
save_positions = (false,save)
DiscreteCallback(condition, affect!,
save_positions=save_positions)
end

export ManifoldProjection
30 changes: 29 additions & 1 deletion test/manifold_tests.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using OrdinaryDiffEq, Test, DiffEqBase, DiffEqCallbacks, RecursiveArrayTools
using OrdinaryDiffEq, Test, DiffEqBase, DiffEqCallbacks, RecursiveArrayTools, ManifoldProjections

u0 = ones(2,2)
f = function (du,u,p,t)
Expand Down Expand Up @@ -64,3 +64,31 @@ sol[end][1]^2 + sol[end][2]^2 ≈ 2

sol = solve(prob,Vern7(),callback=cb_t_false)
sol[end][1]^2 + sol[end][2]^2 ≈ 2

# Test using ManifoldProjections.jl
# Test the equations above now transformed to the complex plane
u0 = (1+1im) * ones(ComplexF64, 2)
function f(du,u,p,t)
@. du[:] = im * u
end
prob = ODEProblem(f,u0,(0.0,100.0))

# Each of the
S = Sphere(sqrt(2))
M = PowerManifold(S, (1,), (2,))

sol = solve(prob,Vern7())
@test !(abs2(sol[end][1]) ≈ 2)

cb = ManifoldProjection(M)
# @test isautonomous(cb.affect!)
solve(prob,Vern7(),callback=cb)
@time sol=solve(prob,Vern7(),callback=cb)
@test all(abs2.(sol[end]) .≈ 2)

# test array partitions
u₀ = ArrayPartition([1.0+im], [1.0+im])
prob = ODEProblem(f, u₀, (0.0, 100.0))

sol = solve(prob,Vern7(),callback=cb)
@test all(abs2.(sol[end]) .≈ 2)