From 78183a1cbdf9a8b48d2de9862808cff09ee41bb0 Mon Sep 17 00:00:00 2001 From: ashutosh-b-b Date: Wed, 10 Aug 2022 13:45:36 -0400 Subject: [PATCH 1/6] Enable RODE over an ensemble of process --- src/rode_solve.jl | 264 +++++++++++++++++++++++++++++++--------------- 1 file changed, 177 insertions(+), 87 deletions(-) diff --git a/src/rode_solve.jl b/src/rode_solve.jl index 16b4637bce..b6cbca6f37 100644 --- a/src/rode_solve.jl +++ b/src/rode_solve.jl @@ -1,38 +1,144 @@ -struct NNRODE{C, W, O, P, K} <: NeuralPDEAlgorithm +struct NNRODE{C, W, O, P, B, K, S <: Union{Nothing, AbstractTrainingStrategy}} <: + NeuralPDEAlgorithm + chain::C + W::W + opt::O + init_params::P + autodiff::Bool + batch::B + strategy::S + kwargs::K +end +function NNRODE(chain, W, opt, init_params = nothing; + strategy = nothing, + autodiff = false, batch = nothing, kwargs...) + NNRODE(chain, W, opt, init_params, autodiff, batch, strategy, kwargs) +end + +mutable struct RODEPhi{C, T, U, S} chain::C - W::W - opt::O - init_params::P - autodiff::Bool - kwargs::K -end -function NNRODE(chain, W, opt = Optim.BFGS(), init_params = nothing; autodiff = false, - kwargs...) - if init_params === nothing - if chain isa Flux.Chain - init_params, re = Flux.destructure(chain) - else - error("Only Flux is support here right now") - end + t0::T + u0::U + function RODEPhi(re::Optimisers.Restructure, t, u0) + new{typeof(re), typeof(t), typeof(u0), Nothing}(re, t, u0) + end +end + +function generate_phi_θ(chain::Flux.Chain, t, u0, init_params::Nothing) + θ, re = Flux.destructure(chain) + RODEPhi(re, t, u0), θ +end + +function generate_phi_θ(chain::Flux.Chain, t, u0, init_params) + θ, re = Flux.destructure(chain) + RODEPhi(re, t, u0), init_params +end + +function (f::RODEPhi{C, T, U})(t::Number, w::Number, + θ) where {C <: Optimisers.Restructure, T, U <: Number} +f.u0 + (t - f.t0) * first(f.chain(θ)(adapt(parameterless_type(θ), [t, w]))) +end + +function (f::RODEPhi{C, T, U})(t::AbstractVector, W::AbstractVector, + θ) where {C <: Optimisers.Restructure, T, U <: Number} +f.u0 .+ (t' .- f.t0) .* f.chain(θ)(adapt(parameterless_type(θ), [t W]')) +end + +function (f::RODEPhi{C, T, U})(t::Number, w::Number, θ) where {C <: Optimisers.Restructure, T, U} +f.u0 + (t - f.t0) * f.chain(θ)(adapt(parameterless_type(θ), [t])) +end + +function (f::RODEPhi{C, T, U})(t::AbstractVector, w::AbstractVector, + θ) where {C <: Optimisers.Restructure, T, U} +f.u0 .+ (t .- f.t0) .* f.chain(θ)(adapt(parameterless_type(θ), [t, W]')) +end + +function rode_dfdx end + +function rode_dfdx(phi::RODEPhi{C, T, U}, t::Number, W::Number, θ, + autodiff::Bool) where {C, T, U <: Number} + if autodiff + ForwardDiff.derivative(t -> phi(t, W, θ), t) + else + (phi(t + sqrt(eps(typeof(t))), W, θ) - phi(t, W, θ)) / sqrt(eps(typeof(t))) + end +end + +function rode_dfdx(phi::RODEPhi{C, T, U}, t::Number, W::Number, θ, + autodiff::Bool) where {C, T, U <: AbstractVector} + if autodiff + ForwardDiff.jacobian(t -> phi(t, W, θ), t) + else + (phi(t + sqrt(eps(typeof(t))), θ) - phi(t, W, θ)) / sqrt(eps(typeof(t))) + end +end + +function rode_dfdx(phi::RODEPhi, t::AbstractVector, W::AbstractVector, θ, autodiff::Bool) + if autodiff + ForwardDiff.jacobian(t -> phi(t, W, θ), t) else - init_params = init_params + (phi(t .+ sqrt(eps(eltype(t))), W, θ) - phi(t, W, θ)) ./ sqrt(eps(eltype(t))) end - NNRODE(chain, W, opt, init_params, autodiff, kwargs) end -function DiffEqBase.solve(prob::DiffEqBase.AbstractRODEProblem, - alg::NeuralPDEAlgorithm, - args...; - dt, - timeseries_errors = true, - save_everystep = true, - adaptive = false, - abstol = 1.0f-6, - verbose = false, - maxiters = 100) - DiffEqBase.isinplace(prob) && error("Only out-of-place methods are allowed!") +function inner_loss end + +function inner_loss(phi::RODEPhi{C, T, U}, f, autodiff::Bool, t::Number, W::Number, θ, + p) where {C, T, U <: Number} + sum(abs2, rode_dfdx(phi, t, W, θ, autodiff) - f(phi(t, W, θ), p, t, W)) +end + +function inner_loss(phi::RODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector, W::AbstractVector, θ, + p) where {C, T, U <: Number} + out = phi(t, W, θ) + fs = reduce(hcat, [f(out[i], p, t[i], W[i]) for i in 1:size(out, 2)]) + dxdtguess = Array(rode_dfdx(phi, t, W, θ, autodiff)) + sum(abs2, dxdtguess .- fs) / length(t) +end + +function inner_loss(phi::RODEPhi{C, T, U}, f, autodiff::Bool, t::Number, W::Number, θ, + p) where {C, T, U} + sum(abs2, rode_dfdx(phi, t, W, θ, autodiff) .- f(phi(t, W, θ), p, t, W)) +end + +function inner_loss(phi::RODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector, W::AbstractVector, θ, + p) where {C, T, U} + out = Array(phi(t, W, θ)) + arrt = Array(t) + fs = reduce(hcat, [f(out[:, i], p, arrt[i], W[i]) for i in 1:size(out, 2)]) + dxdtguess = Array(rode_dfdx(phi, t, W, θ, autodiff)) + sum(abs2, dxdtguess .- fs) / length(t) +end + +function generate_loss(strategy::GridTraining, phi, f, autodiff::Bool, tspan, W, p, batch) + ts = tspan[1]:(strategy.dx):tspan[2] + + # sum(abs2,inner_loss(t,θ) for t in ts) but Zygote generators are broken + println(typeof(W)) + function loss(θ, _) + if batch + sum(abs2, [inner_loss(phi, f, autodiff, ts, W[j, :], θ, p) for j in 1:size(W)[1]]) + else + sum(abs2, [sum(abs2, [inner_loss(phi, f, autodiff, t, W[j, :][i], θ, p) for (i, t) in enumerate(ts)]) for j in 1:size(W)[1]]) + end + end + optf = OptimizationFunction(loss, Optimization.AutoZygote()) +end +function DiffEqBase.__solve(prob::DiffEqBase.AbstractRODEProblem, + alg::NNRODE, + args...; + dt = nothing, + timeseries_errors = true, + save_everystep = true, + adaptive = false, + abstol = 1.0f-6, + reltol = 1.0f-3, + verbose = false, + saveat = nothing, + maxiters = nothing) u0 = prob.u0 + W = alg.W tspan = prob.tspan f = prob.f p = prob.p @@ -42,75 +148,59 @@ function DiffEqBase.solve(prob::DiffEqBase.AbstractRODEProblem, chain = alg.chain opt = alg.opt autodiff = alg.autodiff - Wg = alg.W + #train points generation - ts = tspan[1]:dt:tspan[2] init_params = alg.init_params - if chain isa FastChain - #The phi trial solution - if u0 isa Number - phi = (t, W, θ) -> u0 + - (t - tspan[1]) * - first(chain(adapt(DiffEqBase.parameterless_type(θ), [t, W]), - θ)) - else - phi = (t, W, θ) -> u0 + - (t - tspan[1]) * - chain(adapt(DiffEqBase.parameterless_type(θ), [t, W]), θ) - end - else - _, re = Flux.destructure(chain) - #The phi trial solution - if u0 isa Number - phi = (t, W, θ) -> u0 + - (t - t0) * - first(re(θ)(adapt(DiffEqBase.parameterless_type(θ), [t, W]))) - else - phi = (t, W, θ) -> u0 + - (t - t0) * - re(θ)(adapt(DiffEqBase.parameterless_type(θ), [t, W])) - end - end + phi, init_params = generate_phi_θ(chain, t0, u0, init_params) + strategy = isnothing(alg.strategy) ? GridTraining(dt) : alg.strategy + batch = isnothing(alg.batch) ? false : alg.batch - if autodiff - # dfdx = (t,W,θ) -> ForwardDiff.derivative(t->phi(t,θ),t) - else - dfdx = (t, W, θ) -> (phi(t + sqrt(eps(t)), W, θ) - phi(t, W, θ)) / sqrt(eps(t)) - end - - function inner_loss(t, W, θ) - sum(abs, dfdx(t, W, θ) - f(phi(t, W, θ), p, t, W)) - end - Wprob = NoiseProblem(Wg, tspan) - Wsol = solve(Wprob; dt = dt) - W = NoiseGrid(ts, Wsol.W) - function loss(θ) - sum(abs2, inner_loss(ts[i], W.W[i], θ) for i in 1:length(ts)) # sum(abs2,phi(tspan[1],θ) - u0) + W_prob = NoiseProblem(W, tspan) + W_en = EnsembleProblem(W_prob) + W_sim = solve(W_en; dt = dt, trajectories = 100) + W_bf = Zygote.Buffer(rand(length(W_sim), length(W_sim[1]))) + for (i, sol) in enumerate(W_sim) + W_bf[i, :] = sol end + optf = generate_loss(strategy, phi, f, autodiff::Bool, tspan, W_bf, p, batch) + iteration = 0 callback = function (p, l) - Wprob = NoiseProblem(Wg, tspan) - Wsol = solve(Wprob; dt = dt) - W = NoiseGrid(ts, Wsol.W) - verbose && println("Current loss is: $l") + iteration += 1 + verbose && println("Current loss is: $l, Iteration: $iteration") l < abstol end - #res = DiffEqFlux.sciml_train(loss, init_params, opt; cb = callback, maxiters = maxiters, - # alg.kwargs...) + + optprob = OptimizationProblem(optf, init_params) + res = solve(optprob, opt; callback, maxiters, alg.kwargs...) #solutions at timepoints - noiseproblem = NoiseProblem(Wg, tspan) - W = solve(noiseproblem; dt = dt) - if u0 isa Number - u = [(phi(ts[i], W.W[i], res.minimizer)) for i in 1:length(ts)] - else - u = [(phi(ts[i], W.W[i], res.minimizer)) for i in 1:length(ts)] - end + # if saveat isa Number + # ts = tspan[1]:saveat:tspan[2] + # elseif saveat isa AbstractArray + # ts = saveat + # elseif dt !== nothing + # ts = tspan[1]:dt:tspan[2] + # elseif save_everystep + # ts = range(tspan[1], tspan[2], length = 100) + # else + # ts = [tspan[1], tspan[2]] + # end + + # if u0 isa Number + # u = [first(phi(t, res.u)) for t in ts] + # else + # u = [phi(t, res.u) for t in ts] + # end - sol = DiffEqBase.build_solution(prob, alg, ts, u, W = W, calculate_error = false) - DiffEqBase.has_analytic(prob.f) && - DiffEqBase.calculate_solution_errors!(sol; timeseries_errors = true, - dense_errors = false) - sol + # sol = DiffEqBase.build_solution(prob, alg, ts, u; + # k = res, dense = true, + # interp = NNODEInterpolation(phi, res.u), + # calculate_error = false, + # retcode = :Success) + # DiffEqBase.has_analytic(prob.f) && + # DiffEqBase.calculate_solution_errors!(sol; timeseries_errors = true, + # dense_errors = false) + res, phi end #solve From ff2c4f1b419408cd694f87bdfe50ba42207c0dea Mon Sep 17 00:00:00 2001 From: ashutosh-b-b Date: Wed, 10 Aug 2022 15:19:36 -0400 Subject: [PATCH 2/6] Return res, and phi as sol --- src/rode_solve.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rode_solve.jl b/src/rode_solve.jl index b6cbca6f37..e1882061d6 100644 --- a/src/rode_solve.jl +++ b/src/rode_solve.jl @@ -202,5 +202,5 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractRODEProblem, # DiffEqBase.has_analytic(prob.f) && # DiffEqBase.calculate_solution_errors!(sol; timeseries_errors = true, # dense_errors = false) - res, phi + res, u(t, W) -> phi(t, W, res.u) end #solve From 2ec48eed2f41f370ef57f177319fe7c73c3bb959 Mon Sep 17 00:00:00 2001 From: ashutosh-b-b Date: Tue, 16 Aug 2022 03:47:30 -0400 Subject: [PATCH 3/6] Solve using LuxChains --- src/rode_solve.jl | 81 +++++++++++++++++++++++++++++------------------ 1 file changed, 51 insertions(+), 30 deletions(-) diff --git a/src/rode_solve.jl b/src/rode_solve.jl index e1882061d6..cf0a818c67 100644 --- a/src/rode_solve.jl +++ b/src/rode_solve.jl @@ -19,11 +19,27 @@ mutable struct RODEPhi{C, T, U, S} chain::C t0::T u0::U + st::S + + function RODEPhi(chain::Lux.AbstractExplicitLayer, t::Number, u0, st) + new{typeof(chain), typeof(t), typeof(u0), typeof(st)}(chain, t, u0, st) + end + function RODEPhi(re::Optimisers.Restructure, t, u0) - new{typeof(re), typeof(t), typeof(u0), Nothing}(re, t, u0) + new{typeof(re), typeof(t), typeof(u0), Nothing}(re, t, u0, nothing) end end +function generate_phi_θ(chain::Lux.AbstractExplicitLayer, t, u0, init_params::Nothing) + θ, st = Lux.setup(Random.default_rng(), chain) + RODEPhi(chain, t, u0, st), ComponentArrays.ComponentArray(θ) +end + +function generate_phi_θ(chain::Lux.AbstractExplicitLayer, t, u0, init_params) + θ, st = Lux.setup(Random.default_rng(), chain) + RODEPhi(chain, t, u0, st), ComponentArrays.ComponentArray(init_params) +end + function generate_phi_θ(chain::Flux.Chain, t, u0, init_params::Nothing) θ, re = Flux.destructure(chain) RODEPhi(re, t, u0), θ @@ -34,6 +50,36 @@ function generate_phi_θ(chain::Flux.Chain, t, u0, init_params) RODEPhi(re, t, u0), init_params end +function (f::RODEPhi{C, T, U})(t::Number, W::Number, + θ) where {C <: Lux.AbstractExplicitLayer, T, U <: Number} + y, st = f.chain(adapt(parameterless_type(θ), [t ; W]), θ, f.st) + ChainRulesCore.@ignore_derivatives f.st = st + f.u0 + (t - f.t0) * first(y) +end + +function (f::RODEPhi{C, T, U})(t::AbstractVector, W::AbstractVector, + θ) where {C <: Lux.AbstractExplicitLayer, T, U <: Number} +# Batch via data as row vectors + y, st = f.chain(adapt(parameterless_type(θ), [t W]'), θ, f.st) + ChainRulesCore.@ignore_derivatives f.st = st + f.u0 .+ (t' .- f.t0) .* y +end + +function (f::RODEPhi{C, T, U})(t::Number, W::Number, θ) where {C <: Lux.AbstractExplicitLayer, T, U} +y, st = f.chain(adapt(parameterless_type(θ), [t W]), θ, f.st) +ChainRulesCore.@ignore_derivatives f.st = st +f.u0 .+ (t .- f.t0) .* y +end + +function (f::RODEPhi{C, T, U})(t::AbstractVector, W::AbstractVector, + θ) where {C <: Lux.AbstractExplicitLayer, T, U} +# Batch via data as row vectors +y, st = f.chain(adapt(parameterless_type(θ), [t W]'), θ, f.st) +ChainRulesCore.@ignore_derivatives f.st = st +f.u0 .+ (t' .- f.t0) .* y +end + + function (f::RODEPhi{C, T, U})(t::Number, w::Number, θ) where {C <: Optimisers.Restructure, T, U <: Number} f.u0 + (t - f.t0) * first(f.chain(θ)(adapt(parameterless_type(θ), [t, w]))) @@ -129,6 +175,7 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractRODEProblem, alg::NNRODE, args...; dt = nothing, + trajectories = 100, timeseries_errors = true, save_everystep = true, adaptive = false, @@ -153,12 +200,13 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractRODEProblem, init_params = alg.init_params phi, init_params = generate_phi_θ(chain, t0, u0, init_params) + strategy = isnothing(alg.strategy) ? GridTraining(dt) : alg.strategy batch = isnothing(alg.batch) ? false : alg.batch W_prob = NoiseProblem(W, tspan) W_en = EnsembleProblem(W_prob) - W_sim = solve(W_en; dt = dt, trajectories = 100) + W_sim = solve(W_en; dt = dt, trajectories = trajectories) W_bf = Zygote.Buffer(rand(length(W_sim), length(W_sim[1]))) for (i, sol) in enumerate(W_sim) W_bf[i, :] = sol @@ -175,32 +223,5 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractRODEProblem, optprob = OptimizationProblem(optf, init_params) res = solve(optprob, opt; callback, maxiters, alg.kwargs...) - #solutions at timepoints - # if saveat isa Number - # ts = tspan[1]:saveat:tspan[2] - # elseif saveat isa AbstractArray - # ts = saveat - # elseif dt !== nothing - # ts = tspan[1]:dt:tspan[2] - # elseif save_everystep - # ts = range(tspan[1], tspan[2], length = 100) - # else - # ts = [tspan[1], tspan[2]] - # end - - # if u0 isa Number - # u = [first(phi(t, res.u)) for t in ts] - # else - # u = [phi(t, res.u) for t in ts] - # end - - # sol = DiffEqBase.build_solution(prob, alg, ts, u; - # k = res, dense = true, - # interp = NNODEInterpolation(phi, res.u), - # calculate_error = false, - # retcode = :Success) - # DiffEqBase.has_analytic(prob.f) && - # DiffEqBase.calculate_solution_errors!(sol; timeseries_errors = true, - # dense_errors = false) - res, u(t, W) -> phi(t, W, res.u) + res, (t, W) -> phi(t, W, res.u) end #solve From 989860dca25ec5c2fe10b5c887c195c867f26926 Mon Sep 17 00:00:00 2001 From: ashutosh-b-b Date: Tue, 16 Aug 2022 03:47:56 -0400 Subject: [PATCH 4/6] Update tests --- test/NNRODE_tests.jl | 82 +++++++++++++++++++++++++++++++++++--------- 1 file changed, 65 insertions(+), 17 deletions(-) diff --git a/test/NNRODE_tests.jl b/test/NNRODE_tests.jl index da8f87ec7f..fa5cd017c7 100644 --- a/test/NNRODE_tests.jl +++ b/test/NNRODE_tests.jl @@ -1,6 +1,8 @@ using Flux, OptimizationOptimisers, StochasticDiffEq, DiffEqNoiseProcess, Optim, Test +import Lux, OptimizationOptimJL using NeuralPDE + using Random Random.seed!(100) @@ -11,15 +13,40 @@ u0 = 1.0f0 dt = 1 / 50.0f0 W = WienerProcess(0.0, 0.0, nothing) prob = RODEProblem(linear, u0, tspan, noise = W) +opt = OptimizationOptimisers.Adam(0.01) + +W_test = solve(NoiseProblem(W, tspan), dt = dt) +prob1 = RODEProblem(linear, u0, tspan, noise = W_test) +analytical_sol = solve(prob1, RandomEM(), dt = dt) + +ts = tspan[1]:dt:tspan[2] + chain = Flux.Chain(Dense(2, 8, relu), Dense(8, 16, relu), Dense(16, 1)) -opt = OptimizationOptimisers.Adam(1e-4) -sol = solve(prob, NeuralPDE.NNRODE(chain, W, opt), dt = dt, verbose = true, - abstol = 1e-10, maxiters = 3000) -W2 = NoiseWrapper(sol.W) -prob1 = RODEProblem(linear, u0, tspan, noise = W2) -sol2 = solve(prob1, RandomEM(), dt = dt) -err = Flux.mse(sol.u, sol2.u) -@test err < 0.3 +sol1 = solve(prob, NeuralPDE.NNRODE(chain, W, opt), dt = dt, verbose = true, + abstol = 1e-10, maxiters = 500) +err = Flux.mse(vec(sol1[2](collect(ts), W_test.u)), analytical_sol.u) +@test err < 0.3 + +chain = Flux.Chain(Dense(2, 8, relu), Dense(8, 16, relu), Dense(16, 1)) +sol2 = solve(prob, NeuralPDE.NNRODE(chain, W, opt, batch = true), dt = dt, verbose = true, + abstol = 1e-10, maxiters = 500) +err = Flux.mse(vec(sol2[2](collect(ts), W_test.u)), analytical_sol.u) +@test err < 0.3 + + +luxchain = Lux.Chain(Lux.Dense(2, 8, relu), Lux.Dense(8, 16, relu), Lux.Dense(16, 1)) +sol1 = solve(prob, NeuralPDE.NNRODE(luxchain, W, opt), dt = dt, verbose = true, + abstol = 1e-10, maxiters = 500) +err = Flux.mse(vec(sol1[2](collect(ts), W_test.u)), analytical_sol.u) +@test err < 0.3 + +luxchain = Lux.Chain(Lux.Dense(2, 8, relu), Lux.Dense(8, 16, relu), Lux.Dense(16, 1)) +sol2 = solve(prob, NeuralPDE.NNRODE(luxchain, W, opt, batch = true), dt = dt, verbose = true, + abstol = 1e-10, maxiters = 500) +err = Flux.mse(vec(sol2[2](collect(ts), W_test.u)), analytical_sol.u) +@test err < 0.3 + + println("Test Case 2") linear = (u, p, t, W) -> t^3 + 2 * t + (t^2) * ((1 + 3 * (t^2)) / (1 + t + (t^3))) - @@ -29,12 +56,33 @@ u0 = 1.0f0 dt = 1 / 100.0f0 W = WienerProcess(0.0, 0.0, nothing) prob = RODEProblem(linear, u0, tspan, noise = W) -chain = Flux.Chain(Dense(2, 32, sigmoid), Dense(32, 32, sigmoid), Dense(32, 1)) -opt = OptimizationOptimisers.Adam(1e-3) -sol = solve(prob, NeuralPDE.NNRODE(chain, W, opt), dt = dt, verbose = true, - abstol = 1e-10, maxiters = 2000) -W2 = NoiseWrapper(sol.W) -prob1 = RODEProblem(linear, u0, tspan, noise = W2) -sol2 = solve(prob1, RandomEM(), dt = dt) -err = Flux.mse(sol.u, sol2.u) -@test err < 0.4 +W_test = solve(NoiseProblem(W, tspan), dt = dt) +prob1 = RODEProblem(linear, u0, tspan, noise = W_test) +analytical_sol = solve(prob1, RandomEM(), dt = dt) + +ts = tspan[1]:dt:tspan[2] + +chain = Flux.Chain(Dense(2, 8, relu), Dense(8, 16, relu), Dense(16, 1)) +sol1 = solve(prob, NeuralPDE.NNRODE(chain, W, opt), dt = dt, verbose = true, + abstol = 1e-10, maxiters = 500) +err = Flux.mse(vec(sol1[2](collect(ts), W_test.u)), analytical_sol.u) +@test err < 0.3 + +chain = Flux.Chain(Dense(2, 8, relu), Dense(8, 16, relu), Dense(16, 1)) +sol2 = solve(prob, NeuralPDE.NNRODE(chain, W, opt, batch = true), dt = dt, verbose = true, + abstol = 1e-10, maxiters = 500) +err = Flux.mse(vec(sol2[2](collect(ts), W_test.u)), analytical_sol.u) +@test err < 0.3 + + luxchain = Lux.Chain(Lux.Dense(2, 8, relu), Lux.Dense(8, 16, relu), Lux.Dense(16, 1)) + sol1 = solve(prob, NeuralPDE.NNRODE(luxchain, W, opt), dt = dt, verbose = true, + abstol = 1e-10, maxiters = 500) + err = Flux.mse(vec(sol1[2](collect(ts), W_test.u)), analytical_sol.u) + @test err < 0.3 + + luxchain = Lux.Chain(Lux.Dense(2, 8, relu), Lux.Dense(8, 16, relu), Lux.Dense(16, 1)) + sol2 = solve(prob, NeuralPDE.NNRODE(luxchain, W, opt, batch = true), dt = dt, verbose = true, + abstol = 1e-10, maxiters = 500) + err = Flux.mse(vec(sol2[2](collect(ts), W_test.u)), analytical_sol.u) + @test err < 0.3 + From 4adf7a470467a1eede08edbbc3cac221b2c9c010 Mon Sep 17 00:00:00 2001 From: ashutosh-b-b Date: Tue, 16 Aug 2022 04:34:15 -0400 Subject: [PATCH 5/6] =?UTF-8?q?Rename=20generate=5Fphi=5F=CE=B8=20to=20gen?= =?UTF-8?q?erate=5Fphi=5F=CE=B8=5Frode?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/rode_solve.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/rode_solve.jl b/src/rode_solve.jl index cf0a818c67..8212b8d76d 100644 --- a/src/rode_solve.jl +++ b/src/rode_solve.jl @@ -30,22 +30,22 @@ mutable struct RODEPhi{C, T, U, S} end end -function generate_phi_θ(chain::Lux.AbstractExplicitLayer, t, u0, init_params::Nothing) +function generate_phi_θ_rode(chain::Lux.AbstractExplicitLayer, t, u0, init_params::Nothing) θ, st = Lux.setup(Random.default_rng(), chain) RODEPhi(chain, t, u0, st), ComponentArrays.ComponentArray(θ) end -function generate_phi_θ(chain::Lux.AbstractExplicitLayer, t, u0, init_params) +function generate_phi_θ_rode(chain::Lux.AbstractExplicitLayer, t, u0, init_params) θ, st = Lux.setup(Random.default_rng(), chain) RODEPhi(chain, t, u0, st), ComponentArrays.ComponentArray(init_params) end -function generate_phi_θ(chain::Flux.Chain, t, u0, init_params::Nothing) +function generate_phi_θ_rode(chain::Flux.Chain, t, u0, init_params::Nothing) θ, re = Flux.destructure(chain) RODEPhi(re, t, u0), θ end -function generate_phi_θ(chain::Flux.Chain, t, u0, init_params) +function generate_phi_θ_rode(chain::Flux.Chain, t, u0, init_params) θ, re = Flux.destructure(chain) RODEPhi(re, t, u0), init_params end @@ -199,7 +199,7 @@ function DiffEqBase.__solve(prob::DiffEqBase.AbstractRODEProblem, #train points generation init_params = alg.init_params - phi, init_params = generate_phi_θ(chain, t0, u0, init_params) + phi, init_params = generate_phi_θ_rode(chain, t0, u0, init_params) strategy = isnothing(alg.strategy) ? GridTraining(dt) : alg.strategy batch = isnothing(alg.batch) ? false : alg.batch From 90ef673d67bb18ccc162fab3f52aa961abefe30a Mon Sep 17 00:00:00 2001 From: ashutosh-b-b Date: Tue, 16 Aug 2022 06:50:03 -0400 Subject: [PATCH 6/6] Enable RODE tests --- test/runtests.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index d9846cd54b..e249d83a5c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -35,12 +35,9 @@ end @time @safetestset "AdaptiveLoss" begin include("adaptive_loss_tests.jl") end end - #= - # Fails because it uses sciml_train if GROUP == "All" || GROUP == "NNRODE" @time @safetestset "NNRODE" begin include("NNRODE_tests.jl") end end - =# if GROUP == "All" || GROUP == "Forward" @time @safetestset "Forward" begin include("forward_tests.jl") end