diff --git a/src/rode_solve.jl b/src/rode_solve.jl
index 16b4637bce..8212b8d76d 100644
--- a/src/rode_solve.jl
+++ b/src/rode_solve.jl
@@ -1,38 +1,191 @@
-struct NNRODE{C, W, O, P, K} <: NeuralPDEAlgorithm
+struct NNRODE{C, W, O, P, B, K, S <: Union{Nothing, AbstractTrainingStrategy}} <:
+       NeuralPDEAlgorithm
+    chain::C
+    W::W
+    opt::O
+    init_params::P
+    autodiff::Bool
+    batch::B
+    strategy::S
+    kwargs::K
+end
+function NNRODE(chain, W, opt, init_params = nothing;
+                strategy = nothing,
+                autodiff = false, batch = nothing, kwargs...)
+    NNRODE(chain, W, opt, init_params, autodiff, batch, strategy, kwargs)
+end
+
+mutable struct RODEPhi{C, T, U, S}
     chain::C
-    W::W
-    opt::O
-    init_params::P
-    autodiff::Bool
-    kwargs::K
-end
-function NNRODE(chain, W, opt = Optim.BFGS(), init_params = nothing; autodiff = false,
-                kwargs...)
-    if init_params === nothing
-        if chain isa Flux.Chain
-            init_params, re = Flux.destructure(chain)
-        else
-            error("Only Flux is support here right now")
-        end
+    t0::T
+    u0::U
+    st::S
+
+    function RODEPhi(chain::Lux.AbstractExplicitLayer, t::Number, u0, st)
+        new{typeof(chain), typeof(t), typeof(u0), typeof(st)}(chain, t, u0, st)
+    end
+
+    function RODEPhi(re::Optimisers.Restructure, t, u0)
+        new{typeof(re), typeof(t), typeof(u0), Nothing}(re, t, u0, nothing)
+    end
+end
+
+function generate_phi_θ_rode(chain::Lux.AbstractExplicitLayer, t, u0, init_params::Nothing)
+    θ, st = Lux.setup(Random.default_rng(), chain)
+    RODEPhi(chain, t, u0, st), ComponentArrays.ComponentArray(θ)
+end
+
+function generate_phi_θ_rode(chain::Lux.AbstractExplicitLayer, t, u0, init_params)
+    θ, st = Lux.setup(Random.default_rng(), chain)
+    RODEPhi(chain, t, u0, st), ComponentArrays.ComponentArray(init_params)
+end
+
+function generate_phi_θ_rode(chain::Flux.Chain, t, u0, init_params::Nothing)
+    θ, re = Flux.destructure(chain)
+    RODEPhi(re, t, u0), θ
+end
+
+function generate_phi_θ_rode(chain::Flux.Chain, t, u0, init_params)
+    θ, re = Flux.destructure(chain)
+    RODEPhi(re, t, u0), init_params
+end
+
+function (f::RODEPhi{C, T, U})(t::Number, W::Number,
+                               θ) where {C <: Lux.AbstractExplicitLayer, T, U <: Number}
+    y, st = f.chain(adapt(parameterless_type(θ), [t; W]), θ, f.st)
+    ChainRulesCore.@ignore_derivatives f.st = st
+    f.u0 + (t - f.t0) * first(y)
+end
+
+function (f::RODEPhi{C, T, U})(t::AbstractVector, W::AbstractVector,
+                               θ) where {C <: Lux.AbstractExplicitLayer, T, U <: Number}
+    # Batch via data as row vectors
+    y, st = f.chain(adapt(parameterless_type(θ), [t W]'), θ, f.st)
+    ChainRulesCore.@ignore_derivatives f.st = st
+    f.u0 .+ (t' .- f.t0) .* y
+end
+
+function (f::RODEPhi{C, T, U})(t::Number, W::Number,
+                               θ) where {C <: Lux.AbstractExplicitLayer, T, U}
+    # [t; W] (not [t W]) so the input is a length-2 vector, as Dense expects
+    y, st = f.chain(adapt(parameterless_type(θ), [t; W]), θ, f.st)
+    ChainRulesCore.@ignore_derivatives f.st = st
+    f.u0 .+ (t .- f.t0) .* y
+end
+
+function (f::RODEPhi{C, T, U})(t::AbstractVector, W::AbstractVector,
+                               θ) where {C <: Lux.AbstractExplicitLayer, T, U}
+    # Batch via data as row vectors
+    y, st = f.chain(adapt(parameterless_type(θ), [t W]'), θ, f.st)
+    ChainRulesCore.@ignore_derivatives f.st = st
+    f.u0 .+ (t' .- f.t0) .* y
+end
+
+function (f::RODEPhi{C, T, U})(t::Number, w::Number,
+                               θ) where {C <: Optimisers.Restructure, T, U <: Number}
+    f.u0 + (t - f.t0) * first(f.chain(θ)(adapt(parameterless_type(θ), [t, w])))
+end
+
+function (f::RODEPhi{C, T, U})(t::AbstractVector, W::AbstractVector,
+                               θ) where {C <: Optimisers.Restructure, T, U <: Number}
+    f.u0 .+ (t' .- f.t0) .* f.chain(θ)(adapt(parameterless_type(θ), [t W]'))
+end
+
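+# Non-scalar u0 variants of the Restructure (Flux) trial solution: broadcast the
+# (t - t0) scaling over the network output instead of taking `first`.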
+function (f::RODEPhi{C, T, U})(t::Number, w::Number,
+                               θ) where {C <: Optimisers.Restructure, T, U}
+    f.u0 .+ (t - f.t0) .* f.chain(θ)(adapt(parameterless_type(θ), [t, w]))
+end
+
+function (f::RODEPhi{C, T, U})(t::AbstractVector, w::AbstractVector,
+                               θ) where {C <: Optimisers.Restructure, T, U}
+    f.u0 .+ (t' .- f.t0) .* f.chain(θ)(adapt(parameterless_type(θ), [t w]'))
+end
+
+function rode_dfdx end
+
+# Time derivative of the trial solution, via ForwardDiff or a finite difference
+function rode_dfdx(phi::RODEPhi{C, T, U}, t::Number, W::Number, θ,
+                   autodiff::Bool) where {C, T, U <: Number}
+    if autodiff
+        ForwardDiff.derivative(t -> phi(t, W, θ), t)
     else
-        init_params = init_params
+        (phi(t + sqrt(eps(typeof(t))), W, θ) - phi(t, W, θ)) / sqrt(eps(typeof(t)))
     end
-    NNRODE(chain, W, opt, init_params, autodiff, kwargs)
 end
 
-function DiffEqBase.solve(prob::DiffEqBase.AbstractRODEProblem,
-                          alg::NeuralPDEAlgorithm,
-                          args...;
-                          dt,
-                          timeseries_errors = true,
-                          save_everystep = true,
-                          adaptive = false,
-                          abstol = 1.0f-6,
-                          verbose = false,
-                          maxiters = 100)
-    DiffEqBase.isinplace(prob) && error("Only out-of-place methods are allowed!")
+function rode_dfdx(phi::RODEPhi{C, T, U}, t::Number, W::Number, θ,
+                   autodiff::Bool) where {C, T, U <: AbstractVector}
+    if autodiff
+        # scalar input, vector output: derivative, not jacobian
+        ForwardDiff.derivative(t -> phi(t, W, θ), t)
+    else
+        (phi(t + sqrt(eps(typeof(t))), W, θ) - phi(t, W, θ)) / sqrt(eps(typeof(t)))
+    end
+end
+
+function rode_dfdx(phi::RODEPhi, t::AbstractVector, W::AbstractVector, θ, autodiff::Bool)
+    if autodiff
+        ForwardDiff.jacobian(t -> phi(t, W, θ), t)
+    else
+        (phi(t .+ sqrt(eps(eltype(t))), W, θ) - phi(t, W, θ)) ./ sqrt(eps(eltype(t)))
+    end
+end
+
+function inner_loss end
+
+function inner_loss(phi::RODEPhi{C, T, U}, f, autodiff::Bool, t::Number, W::Number, θ,
+                    p) where {C, T, U <: Number}
+    sum(abs2, rode_dfdx(phi, t, W, θ, autodiff) - f(phi(t, W, θ), p, t, W))
+end
+
+function inner_loss(phi::RODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector,
+                    W::AbstractVector, θ, p) where {C, T, U <: Number}
+    out = phi(t, W, θ)
+    fs = reduce(hcat, [f(out[i], p, t[i], W[i]) for i in 1:size(out, 2)])
+    dxdtguess = Array(rode_dfdx(phi, t, W, θ, autodiff))
+    sum(abs2, dxdtguess .- fs) / length(t)
+end
+
+function inner_loss(phi::RODEPhi{C, T, U}, f, autodiff::Bool, t::Number, W::Number, θ,
+                    p) where {C, T, U}
+    sum(abs2, rode_dfdx(phi, t, W, θ, autodiff) .- f(phi(t, W, θ), p, t, W))
+end
+
+function inner_loss(phi::RODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector,
+                    W::AbstractVector, θ, p) where {C, T, U}
+    out = Array(phi(t, W, θ))
+    arrt = Array(t)
+    fs = reduce(hcat, [f(out[:, i], p, arrt[i], W[i]) for i in 1:size(out, 2)])
+    dxdtguess = Array(rode_dfdx(phi, t, W, θ, autodiff))
+    sum(abs2, dxdtguess .- fs) / length(t)
+end
+
+function generate_loss(strategy::GridTraining, phi, f, autodiff::Bool, tspan, W, p, batch)
+    ts = tspan[1]:(strategy.dx):tspan[2]
+    # sum(abs2, inner_loss(t, θ) for t in ts) but Zygote generators are broken
+    function loss(θ, _)
+        if batch
+            sum(abs2,
+                [inner_loss(phi, f, autodiff, ts, W[j, :], θ, p) for j in 1:size(W, 1)])
+        else
+            sum(abs2,
+                [sum(abs2,
+                     [inner_loss(phi, f, autodiff, t, W[j, i], θ, p)
+                      for (i, t) in enumerate(ts)]) for j in 1:size(W, 1)])
+        end
+    end
+    optf = OptimizationFunction(loss, Optimization.AutoZygote())
+end
+
+function DiffEqBase.__solve(prob::DiffEqBase.AbstractRODEProblem,
+                            alg::NNRODE,
+                            args...;
+                            dt = nothing,
+                            trajectories = 100,
+                            timeseries_errors = true,
+                            save_everystep = true,
+                            adaptive = false,
+                            abstol = 1.0f-6,
+                            reltol = 1.0f-3,
+                            verbose = false,
+                            saveat = nothing,
+                            maxiters = nothing)
     u0 = prob.u0
+    W = alg.W
     tspan = prob.tspan
     f = prob.f
     p = prob.p
@@ -42,75 +195,33 @@ function DiffEqBase.solve(prob::DiffEqBase.AbstractRODEProblem,
     chain = alg.chain
    opt = alg.opt
    autodiff = alg.autodiff
-    Wg = alg.W
+
     #train points generation
-    ts = tspan[1]:dt:tspan[2]
     init_params = alg.init_params
-    if chain isa FastChain
-        #The phi trial solution
-        if u0 isa Number
-            phi = (t, W, θ) -> u0 +
-                               (t - tspan[1]) *
-                               first(chain(adapt(DiffEqBase.parameterless_type(θ), [t, W]),
-                                           θ))
-        else
-            phi = (t, W, θ) -> u0 +
-                               (t - tspan[1]) *
-                               chain(adapt(DiffEqBase.parameterless_type(θ), [t, W]), θ)
-        end
-    else
-        _, re = Flux.destructure(chain)
-        #The phi trial solution
-        if u0 isa Number
-            phi = (t, W, θ) -> u0 +
-                               (t - t0) *
-                               first(re(θ)(adapt(DiffEqBase.parameterless_type(θ), [t, W])))
-        else
-            phi = (t, W, θ) -> u0 +
-                               (t - t0) *
-                               re(θ)(adapt(DiffEqBase.parameterless_type(θ), [t, W]))
-        end
-    end
+    phi, init_params = generate_phi_θ_rode(chain, t0, u0, init_params)
 
-    if autodiff
-        # dfdx = (t,W,θ) -> ForwardDiff.derivative(t->phi(t,θ),t)
-    else
-        dfdx = (t, W, θ) -> (phi(t + sqrt(eps(t)), W, θ) - phi(t, W, θ)) / sqrt(eps(t))
-    end
+    isnothing(dt) &&
+        error("dt must be specified: it is used for the noise grid and GridTraining")
+    strategy = isnothing(alg.strategy) ? GridTraining(dt) : alg.strategy
+    batch = isnothing(alg.batch) ? false : alg.batch
 
-    function inner_loss(t, W, θ)
-        sum(abs, dfdx(t, W, θ) - f(phi(t, W, θ), p, t, W))
-    end
-    Wprob = NoiseProblem(Wg, tspan)
-    Wsol = solve(Wprob; dt = dt)
-    W = NoiseGrid(ts, Wsol.W)
-    function loss(θ)
-        sum(abs2, inner_loss(ts[i], W.W[i], θ) for i in 1:length(ts)) # sum(abs2,phi(tspan[1],θ) - u0)
+    # simulate an ensemble of noise paths on the training grid; Zygote.Buffer keeps
+    # the collected paths safe to index inside the differentiated loss
+    W_prob = NoiseProblem(W, tspan)
+    W_en = EnsembleProblem(W_prob)
+    W_sim = solve(W_en; dt = dt, trajectories = trajectories)
+    W_bf = Zygote.Buffer(rand(length(W_sim), length(W_sim[1])))
+    for (i, sol) in enumerate(W_sim)
+        W_bf[i, :] = sol
     end
+    optf = generate_loss(strategy, phi, f, autodiff, tspan, W_bf, p, batch)
 
+    iteration = 0
     callback = function (p, l)
-        Wprob = NoiseProblem(Wg, tspan)
-        Wsol = solve(Wprob; dt = dt)
-        W = NoiseGrid(ts, Wsol.W)
-        verbose && println("Current loss is: $l")
+        iteration += 1
+        verbose && println("Current loss is: $l, Iteration: $iteration")
         l < abstol
     end
-    #res = DiffEqFlux.sciml_train(loss, init_params, opt; cb = callback, maxiters = maxiters,
-    #                             alg.kwargs...)
-
-    #solutions at timepoints
-    noiseproblem = NoiseProblem(Wg, tspan)
-    W = solve(noiseproblem; dt = dt)
-    if u0 isa Number
-        u = [(phi(ts[i], W.W[i], res.minimizer)) for i in 1:length(ts)]
-    else
-        u = [(phi(ts[i], W.W[i], res.minimizer)) for i in 1:length(ts)]
-    end
-    sol = DiffEqBase.build_solution(prob, alg, ts, u, W = W, calculate_error = false)
-    DiffEqBase.has_analytic(prob.f) &&
-        DiffEqBase.calculate_solution_errors!(sol; timeseries_errors = true,
-                                              dense_errors = false)
-    sol
+    optprob = OptimizationProblem(optf, init_params)
+    res = solve(optprob, opt; callback, maxiters, alg.kwargs...)
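+
+    # return the optimization result together with the trained trial solution,
+    # so callers can evaluate the network at any (t, W)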
+    res, (t, W) -> phi(t, W, res.u)
 end
 #solve
diff --git a/test/NNRODE_tests.jl b/test/NNRODE_tests.jl
index da8f87ec7f..fa5cd017c7 100644
--- a/test/NNRODE_tests.jl
+++ b/test/NNRODE_tests.jl
@@ -1,6 +1,8 @@
 using Flux, OptimizationOptimisers, StochasticDiffEq, DiffEqNoiseProcess, Optim, Test
+import Lux, OptimizationOptimJL
 using NeuralPDE
+
 using Random

 Random.seed!(100)
@@ -11,15 +13,40 @@
 u0 = 1.0f0
 dt = 1 / 50.0f0
 W = WienerProcess(0.0, 0.0, nothing)
 prob = RODEProblem(linear, u0, tspan, noise = W)
+opt = OptimizationOptimisers.Adam(0.01)
+
+W_test = solve(NoiseProblem(W, tspan), dt = dt)
+prob1 = RODEProblem(linear, u0, tspan, noise = W_test)
+analytical_sol = solve(prob1, RandomEM(), dt = dt)
+
+ts = tspan[1]:dt:tspan[2]
+
 chain = Flux.Chain(Dense(2, 8, relu), Dense(8, 16, relu), Dense(16, 1))
-opt = OptimizationOptimisers.Adam(1e-4)
-sol = solve(prob, NeuralPDE.NNRODE(chain, W, opt), dt = dt, verbose = true,
-            abstol = 1e-10, maxiters = 3000)
-W2 = NoiseWrapper(sol.W)
-prob1 = RODEProblem(linear, u0, tspan, noise = W2)
-sol2 = solve(prob1, RandomEM(), dt = dt)
-err = Flux.mse(sol.u, sol2.u)
-@test err < 0.3
+sol1 = solve(prob, NeuralPDE.NNRODE(chain, W, opt), dt = dt, verbose = true,
+             abstol = 1e-10, maxiters = 500)
+err = Flux.mse(vec(sol1[2](collect(ts), W_test.u)), analytical_sol.u)
+@test err < 0.3
+
+chain = Flux.Chain(Dense(2, 8, relu), Dense(8, 16, relu), Dense(16, 1))
+sol2 = solve(prob, NeuralPDE.NNRODE(chain, W, opt, batch = true), dt = dt, verbose = true,
+             abstol = 1e-10, maxiters = 500)
+err = Flux.mse(vec(sol2[2](collect(ts), W_test.u)), analytical_sol.u)
+@test err < 0.3
+
+luxchain = Lux.Chain(Lux.Dense(2, 8, relu), Lux.Dense(8, 16, relu), Lux.Dense(16, 1))
+sol1 = solve(prob, NeuralPDE.NNRODE(luxchain, W, opt), dt = dt, verbose = true,
+             abstol = 1e-10, maxiters = 500)
+err = Flux.mse(vec(sol1[2](collect(ts), W_test.u)), analytical_sol.u)
+@test err < 0.3
+
+luxchain = Lux.Chain(Lux.Dense(2, 8, relu), Lux.Dense(8, 16, relu), Lux.Dense(16, 1))
+sol2 = solve(prob, NeuralPDE.NNRODE(luxchain, W, opt, batch = true), dt = dt,
+             verbose = true, abstol = 1e-10, maxiters = 500)
+err = Flux.mse(vec(sol2[2](collect(ts), W_test.u)), analytical_sol.u)
+@test err < 0.3
+
 println("Test Case 2")
 linear = (u, p, t, W) -> t^3 + 2 * t + (t^2) * ((1 + 3 * (t^2)) / (1 + t + (t^3))) -
@@ -29,12 +56,33 @@
 u0 = 1.0f0
 dt = 1 / 100.0f0
 W = WienerProcess(0.0, 0.0, nothing)
 prob = RODEProblem(linear, u0, tspan, noise = W)
-chain = Flux.Chain(Dense(2, 32, sigmoid), Dense(32, 32, sigmoid), Dense(32, 1))
-opt = OptimizationOptimisers.Adam(1e-3)
-sol = solve(prob, NeuralPDE.NNRODE(chain, W, opt), dt = dt, verbose = true,
-            abstol = 1e-10, maxiters = 2000)
-W2 = NoiseWrapper(sol.W)
-prob1 = RODEProblem(linear, u0, tspan, noise = W2)
-sol2 = solve(prob1, RandomEM(), dt = dt)
-err = Flux.mse(sol.u, sol2.u)
-@test err < 0.4
+W_test = solve(NoiseProblem(W, tspan), dt = dt)
+prob1 = RODEProblem(linear, u0, tspan, noise = W_test)
+analytical_sol = solve(prob1, RandomEM(), dt = dt)
+
+ts = tspan[1]:dt:tspan[2]
+
+chain = Flux.Chain(Dense(2, 8, relu), Dense(8, 16, relu), Dense(16, 1))
+sol1 = solve(prob, NeuralPDE.NNRODE(chain, W, opt), dt = dt, verbose = true,
+             abstol = 1e-10, maxiters = 500)
+err = Flux.mse(vec(sol1[2](collect(ts), W_test.u)), analytical_sol.u)
+@test err < 0.3
+
+chain = Flux.Chain(Dense(2, 8, relu), Dense(8, 16, relu), Dense(16, 1))
+sol2 = solve(prob, NeuralPDE.NNRODE(chain, W, opt, batch = true), dt = dt, verbose = true,
+             abstol = 1e-10, maxiters = 500)
+err = Flux.mse(vec(sol2[2](collect(ts), W_test.u)), analytical_sol.u)
+@test err < 0.3
+
+luxchain = Lux.Chain(Lux.Dense(2, 8, relu), Lux.Dense(8, 16, relu), Lux.Dense(16, 1))
+sol1 = solve(prob, NeuralPDE.NNRODE(luxchain, W, opt), dt = dt, verbose = true,
+             abstol = 1e-10, maxiters = 500)
+err = Flux.mse(vec(sol1[2](collect(ts), W_test.u)), analytical_sol.u)
+@test err < 0.3
+
+luxchain = Lux.Chain(Lux.Dense(2, 8, relu), Lux.Dense(8, 16, relu), Lux.Dense(16, 1))
+sol2 = solve(prob, NeuralPDE.NNRODE(luxchain, W, opt, batch = true), dt = dt,
+             verbose = true, abstol = 1e-10, maxiters = 500)
+err = Flux.mse(vec(sol2[2](collect(ts), W_test.u)), analytical_sol.u)
+@test err < 0.3
diff --git a/test/runtests.jl b/test/runtests.jl
index d9846cd54b..e249d83a5c 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -35,12 +35,9 @@
     @time @safetestset "AdaptiveLoss" begin include("adaptive_loss_tests.jl") end
 end

-#=
-# Fails because it uses sciml_train
 if GROUP == "All" || GROUP == "NNRODE"
     @time @safetestset "NNRODE" begin include("NNRODE_tests.jl") end
 end
-=#

 if GROUP == "All" || GROUP == "Forward"
     @time @safetestset "Forward" begin include("forward_tests.jl") end
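
Usage sketch (not part of the diff): a minimal example of how the interface added
here is meant to be called, mirroring the tests above. The right-hand side `f`
below is a made-up example, and the `trajectories`/`maxiters` values are only
illustrative. `solve` now returns a tuple `(res, trial_solution)`, where the
second element is the trained trial solution `(t, W) -> phi(t, W, res.u)`.

    using Flux, OptimizationOptimisers, StochasticDiffEq, DiffEqNoiseProcess
    using NeuralPDE

    f = (u, p, t, W) -> 2u * sin(W)    # hypothetical scalar RODE right-hand side
    u0 = 1.0f0
    tspan = (0.0f0, 1.0f0)
    dt = 1 / 50.0f0
    W = WienerProcess(0.0, 0.0, nothing)
    prob = RODEProblem(f, u0, tspan, noise = W)

    chain = Flux.Chain(Dense(2, 8, relu), Dense(8, 16, relu), Dense(16, 1))
    opt = OptimizationOptimisers.Adam(0.01)

    # `trajectories` controls how many noise paths the loss averages over
    res, phi_trained = solve(prob, NeuralPDE.NNRODE(chain, W, opt), dt = dt,
                             trajectories = 100, verbose = true, maxiters = 500)

    # evaluate the learned solution on one fixed noise realization
    W_test = solve(NoiseProblem(W, tspan), dt = dt)
    ts = tspan[1]:dt:tspan[2]
    u_pred = vec(phi_trained(collect(ts), W_test.u))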