diff --git a/src/manifold.jl b/src/manifold.jl index 1e451773..0044cc92 100644 --- a/src/manifold.jl +++ b/src/manifold.jl @@ -1,3 +1,6 @@ +# TODO is this necessary and should it be added to REQUIRE? +using ManifoldProjections + # wrapper for non-autonomous functions mutable struct NonAutonomousFunction{F,autonomous} f::F @@ -52,4 +55,12 @@ function ManifoldProjection(g; nlsolve=NLSOLVEJL_SETUP(), save=true, save_positions=save_positions) end +function ManifoldProjection(M::Mfd; save=true) where {Mfd <: Manifold} + affect!(integrator) = retract!(M, integrator.u) + condition = (u, t, integrator) -> true + save_positions = (false,save) + DiscreteCallback(condition, affect!, + save_positions=save_positions) +end + export ManifoldProjection diff --git a/test/manifold_tests.jl b/test/manifold_tests.jl index 2c343f70..821587a6 100644 --- a/test/manifold_tests.jl +++ b/test/manifold_tests.jl @@ -1,4 +1,4 @@ -using OrdinaryDiffEq, Test, DiffEqBase, DiffEqCallbacks, RecursiveArrayTools +using OrdinaryDiffEq, Test, DiffEqBase, DiffEqCallbacks, RecursiveArrayTools, ManifoldProjections u0 = ones(2,2) f = function (du,u,p,t) @@ -64,3 +64,31 @@ sol[end][1]^2 + sol[end][2]^2 ≈ 2 sol = solve(prob,Vern7(),callback=cb_t_false) sol[end][1]^2 + sol[end][2]^2 ≈ 2 + +# Test using ManifoldProjections.jl +# Test the equations above now transformed to the complex plane +u0 = (1+1im) * ones(ComplexF64, 2) +function f(du,u,p,t) + @. du[:] = im * u +end +prob = ODEProblem(f,u0,(0.0,100.0)) + +# Each of the +S = Sphere(sqrt(2)) +M = PowerManifold(S, (1,), (2,)) + +sol = solve(prob,Vern7()) +@test !(abs2(sol[end][1]) ≈ 2) + +cb = ManifoldProjection(M) +# @test isautonomous(cb.affect!) +solve(prob,Vern7(),callback=cb) +@time sol=solve(prob,Vern7(),callback=cb) +@test all(abs2.(sol[end]) .≈ 2) + +# test array partitions +u₀ = ArrayPartition([1.0+im], [1.0+im]) +prob = ODEProblem(f, u₀, (0.0, 100.0)) + +sol = solve(prob,Vern7(),callback=cb) +@test all(abs2.(sol[end]) .≈ 2)