From fadb032e4b385637f4d890a3cb2dce4af1823b8e Mon Sep 17 00:00:00 2001
From: Sathvik Bhagavan
Date: Tue, 18 Jun 2024 06:07:27 +0000
Subject: [PATCH 1/5] feat: compatibility of NNODE with CUDA

---
 src/NeuralPDE.jl |   5 +-
 src/ode_solve.jl | 154 ++++++++++++++++++++++++-----------------------
 2 files changed, 82 insertions(+), 77 deletions(-)

diff --git a/src/NeuralPDE.jl b/src/NeuralPDE.jl
index 1122afc838..17e2df074e 100644
--- a/src/NeuralPDE.jl
+++ b/src/NeuralPDE.jl
@@ -32,7 +32,10 @@ using SciMLBase: @add_kwonly, parameterless_type
 using UnPack: @unpack
 import ChainRulesCore, Lux, ComponentArrays
 using Lux: FromFluxAdaptor
-using ChainRulesCore: @non_differentiable
+using ChainRulesCore: @ignore_derivatives
+using LuxDeviceUtils: LuxCUDADevice, LuxCPUDevice, cpu_device
+using LuxCUDA: CuArray, CUDABackend
+using KernelAbstractions: @kernel, @Const, @index
 
 RuntimeGeneratedFunctions.init(@__MODULE__)
 
diff --git a/src/ode_solve.jl b/src/ode_solve.jl
index 64d7b3ac6c..232360dab5 100644
--- a/src/ode_solve.jl
+++ b/src/ode_solve.jl
@@ -1,7 +1,7 @@
 abstract type NeuralPDEAlgorithm <: SciMLBase.AbstractODEAlgorithm end
 
 """
-    NNODE(chain, opt, init_params = nothing; autodiff = false, batch = 0, additional_loss = nothing, kwargs...)
+    NNODE(chain, opt, init_params = nothing; autodiff = false, batch = true, additional_loss = nothing, kwargs...)
 
 Algorithm for solving ordinary differential equations using a neural network. This is a specialization
 of the physics-informed neural network which is used as a solver for a standard `ODEProblem`.
@@ -21,6 +21,7 @@
       which thus uses the random initialization provided by the neural network library.
 
 ## Keyword Arguments
+
 * `additional_loss`: A function additional_loss(phi, θ) where phi are the neural network trial solutions,
   θ are the weights of the neural network(s).
 * `autodiff`: The switch between automatic and numerical differentiation for
@@ -71,7 +72,7 @@ is an accurate interpolation (up to the neural network training result). In addi
 Lagaris, Isaac E., Aristidis Likas, and Dimitrios I. Fotiadis. "Artificial neural networks for solving
 ordinary and partial differential equations." IEEE Transactions on Neural Networks 9, no. 5 (1998): 987-1000.
 """
-struct NNODE{C, O, P, B, PE, K, AL <: Union{Nothing, Function},
+struct NNODE{C, O, P, B, PE, K, D, AL <: Union{Nothing, Function},
        S <: Union{Nothing, AbstractTrainingStrategy}
    } <:
    NeuralPDEAlgorithm
@@ -83,15 +84,33 @@
     chain::C
     opt::O
     init_params::P
     autodiff::Bool
     batch::B
     strategy::S
     param_estim::PE
     additional_loss::AL
+    device::D
     kwargs::K
 end
 
 function NNODE(chain, opt, init_params = nothing; strategy = nothing,
-        autodiff = false, batch = true, param_estim = false, additional_loss = nothing, kwargs...)
+        autodiff = false, batch = true, param_estim = false,
+        additional_loss = nothing, device = cpu_device(), kwargs...)
     !(chain isa Lux.AbstractExplicitLayer) && (chain = adapt(FromFluxAdaptor(false, false), chain))
     NNODE(chain, opt, init_params, autodiff, batch,
-        strategy, param_estim, additional_loss, kwargs)
+        strategy, param_estim, additional_loss, device, kwargs)
+end
+
+@kernel function custom_broadcast!(f, du, @Const(out), @Const(p), @Const(t))
+    i = @index(Global, Linear)
+    @views @inbounds x = f(out[:, i], p, t[i])
+    du[:, i] .= x
+end
+
+gpu_broadcast = custom_broadcast!(CUDABackend())
+
+function get_array_type(::LuxCUDADevice)
+    CuArray
+end
+
+function get_array_type(::LuxCPUDevice)
+    Array
 end
 
 """
@@ -100,53 +119,41 @@ end
 Internal struct, used for representing the ODE solution as a neural network in a
 form that respects boundary conditions, i.e. `phi(t) = u0 + t*NN(t)`.
 """
-mutable struct ODEPhi{C, T, U, S}
+mutable struct ODEPhi{C, T, U, S, D}
     chain::C
     t0::T
     u0::U
     st::S
-    function ODEPhi(chain::Lux.AbstractExplicitLayer, t::Number, u0, st)
-        new{typeof(chain), typeof(t), typeof(u0), typeof(st)}(chain, t, u0, st)
+    device::D
+    function ODEPhi(chain::Lux.AbstractExplicitLayer, t0::Number, u0, st, device)
+        new{typeof(chain), typeof(t0), typeof(u0), typeof(st), typeof(device)}(
+            chain, t0, u0, st, device)
     end
 end
 
-function generate_phi_θ(chain::Lux.AbstractExplicitLayer, t, u0, init_params)
+function generate_phi_θ(
+        chain::Lux.AbstractExplicitLayer, t0, u0, init_params, device, p, param_estim)
     θ, st = Lux.setup(Random.default_rng(), chain)
     isnothing(init_params) && (init_params = θ)
-    ODEPhi(chain, t, u0, st), init_params
-end
-
-function (f::ODEPhi{C, T, U})(t::Number,
-        θ) where {C <: Lux.AbstractExplicitLayer, T, U <: Number}
-    y, st = f.chain(
-        adapt(parameterless_type(ComponentArrays.getdata(θ.depvar)), [t]), θ.depvar, f.st)
-    ChainRulesCore.@ignore_derivatives f.st = st
-    f.u0 + (t - f.t0) * first(y)
-end
-
-function (f::ODEPhi{C, T, U})(t::AbstractVector,
-        θ) where {C <: Lux.AbstractExplicitLayer, T, U <: Number}
-    # Batch via data as row vectors
-    y, st = f.chain(
-        adapt(parameterless_type(ComponentArrays.getdata(θ.depvar)), t'), θ.depvar, f.st)
-    ChainRulesCore.@ignore_derivatives f.st = st
-    f.u0 .+ (t' .- f.t0) .* y
-end
-
-function (f::ODEPhi{C, T, U})(t::Number, θ) where {C <: Lux.AbstractExplicitLayer, T, U}
-    y, st = f.chain(
-        adapt(parameterless_type(ComponentArrays.getdata(θ.depvar)), [t]), θ.depvar, f.st)
-    ChainRulesCore.@ignore_derivatives f.st = st
-    f.u0 .+ (t .- f.t0) .* y
+    array_type = get_array_type(device)
+    init_params = if param_estim
+        ComponentArrays.ComponentArray(;
+            depvar = init_params, p = p)
+    else
+        ComponentArrays.ComponentArray(;
+            depvar = init_params)
+    end
+    u0_ = u0 isa Number ? u0 : array_type(u0)
+    ODEPhi(chain, t0, u0_, st, device), adapt(array_type, init_params)
 end
 
-function (f::ODEPhi{C, T, U})(t::AbstractVector,
-        θ) where {C <: Lux.AbstractExplicitLayer, T, U}
+function (f::ODEPhi{C, T, U})(
+        t::AbstractVector, θ) where {C <: Lux.AbstractExplicitLayer, T, U}
     # Batch via data as row vectors
     y, st = f.chain(
         adapt(parameterless_type(ComponentArrays.getdata(θ.depvar)), t'), θ.depvar, f.st)
-    ChainRulesCore.@ignore_derivatives f.st = st
-    f.u0 .+ (t' .- f.t0) .* y
+    @ignore_derivatives f.st = st
+    f.u0 .+ (t .- f.t0)' .* y
 end
 
 """
@@ -190,34 +197,38 @@ Simple L2 inner loss at a time `t` with parameters `θ` of the neural network.
 function inner_loss end
 
 function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::Number, θ,
-        p, param_estim::Bool) where {C, T, U <: Number}
-    p_ = param_estim ? θ.p : p
-    sum(abs2, ode_dfdx(phi, t, θ, autodiff) - f(phi(t, θ), p_, t))
+        p, param_estim::Bool) where {C, T, U}
+    array_type = get_array_type(phi.device)
+    p = param_estim ? θ.p : p
+    p = p isa SciMLBase.NullParameters ? p : array_type(p)
+    t = array_type([t])
+    dxdtguess = ode_dfdx(phi, t, θ, autodiff)
+    out = phi(t, θ)
+    fs = rhs(phi.device, f, phi.u0, out, p, t)
+    sum(abs2, dxdtguess .- fs)
 end
 
 function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector, θ,
-        p, param_estim::Bool) where {C, T, U <: Number}
-    p_ = param_estim ? θ.p : p
+        p, param_estim::Bool) where {C, T, U}
+    array_type = get_array_type(phi.device)
+    t = array_type(t)
+    p = param_estim ? θ.p : p
+    p = p isa SciMLBase.NullParameters ? p : array_type(p)
     out = phi(t, θ)
-    fs = reduce(hcat, [f(out[i], p_, t[i]) for i in axes(out, 2)])
-    dxdtguess = Array(ode_dfdx(phi, t, θ, autodiff))
+    fs = rhs(phi.device, f, phi.u0, out, p, t)
+    dxdtguess = ode_dfdx(phi, t, θ, autodiff)
     sum(abs2, dxdtguess .- fs) / length(t)
 end
 
-function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::Number, θ,
-        p, param_estim::Bool) where {C, T, U}
-    p_ = param_estim ? θ.p : p
-    sum(abs2, ode_dfdx(phi, t, θ, autodiff) .- f(phi(t, θ), p_, t))
+function rhs(::LuxCPUDevice, f, u0, out, p, t)
+    u0 isa Number ? reduce(hcat, [f(out[i], p, t[i]) for i in axes(out, 2)]) :
+    reduce(hcat, [f(out[:, i], p, t[i]) for i in axes(out, 2)])
 end
 
-function inner_loss(phi::ODEPhi{C, T, U}, f, autodiff::Bool, t::AbstractVector, θ,
-        p, param_estim::Bool) where {C, T, U}
-    p_ = param_estim ? θ.p : p
-    out = Array(phi(t, θ))
-    arrt = Array(t)
-    fs = reduce(hcat, [f(out[:, i], p_, arrt[i]) for i in 1:size(out, 2)])
-    dxdtguess = Array(ode_dfdx(phi, t, θ, autodiff))
-    sum(abs2, dxdtguess .- fs) / length(t)
+function rhs(::LuxCUDADevice, f, u0, out, p, t)
+    du = similar(out)
+    gpu_broadcast(f, du, out, p, t; workgroupsize = 64, ndrange = size(out, 2))
+    return du
 end
 
 """
@@ -323,8 +333,10 @@ struct NNODEInterpolation{T <: ODEPhi, T2}
     phi::T
     θ::T2
 end
-(f::NNODEInterpolation)(t, idxs::Nothing, ::Type{Val{0}}, p, continuity) = f.phi(t, f.θ)
-(f::NNODEInterpolation)(t, idxs, ::Type{Val{0}}, p, continuity) = f.phi(t, f.θ)[idxs]
+function (f::NNODEInterpolation)(t, idxs::Nothing, ::Type{Val{0}}, p, continuity)
+    vec(f.phi([t], f.θ))
+end
+(f::NNODEInterpolation)(t, idxs, ::Type{Val{0}}, p, continuity) = vec(f.phi([t], f.θ))[idxs]
 
 function (f::NNODEInterpolation)(t::Vector, idxs::Nothing, ::Type{Val{0}}, p, continuity)
     out = f.phi(t, f.θ)
@@ -358,36 +370,25 @@ function SciMLBase.__solve(prob::SciMLBase.AbstractODEProblem,
     p = prob.p
     t0 = tspan[1]
     param_estim = alg.param_estim
-
-    #hidden layer
     chain = alg.chain
     opt = alg.opt
     autodiff = alg.autodiff
-
-    #train points generation
     init_params = alg.init_params
+    device = alg.device
 
     !(chain isa Lux.AbstractExplicitLayer) &&
         error("Only Lux.AbstractExplicitLayer neural networks are supported")
 
-    phi, init_params = generate_phi_θ(chain, t0, u0, init_params)
-    ((eltype(eltype(init_params).types[1]) <: Complex ||
-      eltype(eltype(init_params).types[2]) <: Complex) &&
+    phi, init_params = generate_phi_θ(chain, t0, u0, init_params, device, p, param_estim)
+
+    (eltype(init_params) <: Complex &&
      alg.strategy isa QuadratureTraining) &&
        error("QuadratureTraining cannot be used with complex parameters. Use other strategies.")
 
-    init_params = if alg.param_estim
-        ComponentArrays.ComponentArray(;
-            depvar = ComponentArrays.ComponentArray(init_params), p = prob.p)
-    else
-        ComponentArrays.ComponentArray(;
-            depvar = ComponentArrays.ComponentArray(init_params))
-    end
-
     isinplace(prob) &&
        throw(error("The NNODE solver only supports out-of-place ODE definitions, i.e. du=f(u,p,t)."))
 
     try
-        phi(t0, init_params)
+        phi(get_array_type(device)([t0]), init_params)
     catch err
         if isa(err, DimensionMismatch)
            throw(DimensionMismatch("Dimensions of the initial u0 and chain should match"))
@@ -473,10 +474,11 @@
         ts = [tspan[1], tspan[2]]
     end
 
+    u = phi(ts, res.u)
     if u0 isa Number
-        u = [first(phi(t, res.u)) for t in ts]
+        u = vec(u)
     else
-        u = [phi(t, res.u) for t in ts]
+        u = [u[:, i] for i in 1:size(u, 2)]
     end
 
     sol = SciMLBase.build_solution(prob, alg, ts, u;

From e625c94a000d669940b919daaacd9674e20061e7 Mon Sep 17 00:00:00 2001
From: Sathvik Bhagavan
Date: Tue, 18 Jun 2024 06:07:52 +0000
Subject: [PATCH 2/5] build: add KA, LuxDeviceUtils, LuxCUDA

---
 Project.toml | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/Project.toml b/Project.toml
index f02da7c7ca..62b60f10ec 100644
--- a/Project.toml
+++ b/Project.toml
@@ -17,9 +17,12 @@ DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
 ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 Integrals = "de52edbc-65ea-441a-8357-d3a637375a31"
+KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
+LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
+LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553"
 MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
 ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
 MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
@@ -54,11 +57,13 @@ Flux = "0.14.11"
 ForwardDiff = "0.10.36"
 Functors = "0.4.4"
 Integrals = "4.4"
+KernelAbstractions = "0.9"
 LineSearches = "7.2"
 LinearAlgebra = "1"
 LogDensityProblems = "2"
 Lux = "0.5.22"
 LuxCUDA = "0.3.2"
+LuxDeviceUtils = "0.1"
 MCMCChains = "6"
 MethodOfLines = "0.11"
 ModelingToolkit = "9.9"

From 5b3085727a0ea1384602bc0f75e23dabaac78626 Mon Sep 17 00:00:00 2001
From: Sathvik Bhagavan
Date: Tue, 18 Jun 2024 08:55:00 +0000
Subject: [PATCH 3/5] test: update NNODE tests - forward pass in additional loss

---
 test/NNODE_tests.jl | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/test/NNODE_tests.jl b/test/NNODE_tests.jl
index 0cd688e310..7fe631ee33 100644
--- a/test/NNODE_tests.jl
+++ b/test/NNODE_tests.jl
@@ -190,7 +190,7 @@ end
     luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1))
     (u_, t_) = (u_analytical(ts), ts)
     function additional_loss(phi, θ)
-        return sum(sum(abs2, [phi(t, θ) for t in t_] .- u_)) / length(u_)
+        return sum(sum(abs2, vec(phi(t_, θ)) .- u_)) / length(u_)
     end
     alg1 = NNODE(luxchain, opt, strategy = GridTraining(0.01),
         additional_loss = additional_loss)
@@ -203,7 +203,7 @@ end
     luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1))
     (u_, t_) = (u_analytical(ts), ts)
     function additional_loss(phi, θ)
-        return sum(sum(abs2, [phi(t, θ) for t in t_] .- u_)) / length(u_)
+        return sum(sum(abs2, vec(phi(t_, θ)) .- u_)) / length(u_)
     end
     alg1 = NNODE(luxchain, opt, additional_loss = additional_loss)
     sol1 = solve(prob, alg1, verbose = false, abstol = 1e-10, maxiters = 200)
@@ -215,7 +215,7 @@ end
     luxchain = Lux.Chain(Lux.Dense(1, 5, Lux.σ), Lux.Dense(5, 1))
     (u_, t_) = (u_analytical(ts), ts)
     function additional_loss(phi, θ)
-        return sum(sum(abs2, [phi(t, θ) for t in t_] .- u_)) / length(u_)
+        return sum(sum(abs2, vec(phi(t_, θ)) .- u_)) / length(u_)
     end
     alg1 = NNODE(luxchain, opt, strategy = StochasticTraining(1000),
         additional_loss = additional_loss)

From 2cda52d78e5eb3661dbe8a0d27ad539bb2f6ebe6 Mon Sep 17 00:00:00 2001
From: Sathvik Bhagavan
Date: Sat, 29 Jun 2024 11:24:43 +0000
Subject: [PATCH 4/5] build: bump compats

---
 Project.toml | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/Project.toml b/Project.toml
index 62b60f10ec..e1d6992975 100644
--- a/Project.toml
+++ b/Project.toml
@@ -46,24 +46,24 @@ AdvancedHMC = "0.6.1"
 Aqua = "0.8"
 ArrayInterface = "7.9"
 CUDA = "5.2"
-ChainRulesCore = "1.21"
-ComponentArrays = "0.15.8"
+ChainRulesCore = "1.24"
+ComponentArrays = "0.15.14"
 Cubature = "1.5"
 DiffEqNoiseProcess = "5.20"
 Distributions = "0.25.107"
-DocStringExtensions = "0.9"
+DocStringExtensions = "0.9.3"
 DomainSets = "0.6, 0.7"
 Flux = "0.14.11"
 ForwardDiff = "0.10.36"
-Functors = "0.4.4"
+Functors = "0.4.10"
 Integrals = "4.4"
-KernelAbstractions = "0.9"
+KernelAbstractions = "0.9.22"
 LineSearches = "7.2"
 LinearAlgebra = "1"
 LogDensityProblems = "2"
-Lux = "0.5.22"
+Lux = "0.5.57"
 LuxCUDA = "0.3.2"
-LuxDeviceUtils = "0.1"
+LuxDeviceUtils = "0.1.24"
 MCMCChains = "6"
 MethodOfLines = "0.11"
 ModelingToolkit = "9.9"

From 58d617eae533f66b9006e531b31e8419fb574fd0 Mon Sep 17 00:00:00 2001
From: Sathvik Bhagavan
Date: Sat, 29 Jun 2024 11:41:24 +0000
Subject: [PATCH 5/5] fixup! build: bump compats

---
 Project.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Project.toml b/Project.toml
index e1d6992975..82d4066745 100644
--- a/Project.toml
+++ b/Project.toml
@@ -45,7 +45,7 @@ Adapt = "4"
 AdvancedHMC = "0.6.1"
 Aqua = "0.8"
 ArrayInterface = "7.9"
-CUDA = "5.2"
+CUDA = "5.3"
 ChainRulesCore = "1.24"
 ComponentArrays = "0.15.14"
 Cubature = "1.5"
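
A minimal usage sketch of the `device` keyword this series introduces, assuming a functional CUDA setup and the compat bounds pinned above. The ODE, chain, optimizer, and hyperparameters are illustrative only; `gpu_device()` from LuxDeviceUtils is assumed to return a `LuxCUDADevice` once LuxCUDA is loaded and functional, while the `cpu_device()` default keeps the existing CPU behavior. The RHS is written to return a scalar and avoid allocation, since on the GPU path it runs inside the broadcast kernel:

using NeuralPDE, Lux, LuxCUDA, LuxDeviceUtils
using OptimizationOptimisers: Adam
using SciMLBase: ODEProblem, solve

# Out-of-place definition, as NNODE requires: du = f(u, p, t).
# `first(u)` also works on the CPU path, where `u` arrives as a scalar.
f(u, p, t) = -2.0f0 * first(u)
prob = ODEProblem(f, 1.0f0, (0.0f0, 1.0f0))

chain = Lux.Chain(Lux.Dense(1, 16, Lux.σ), Lux.Dense(16, 1))

# Swapping gpu_device() for cpu_device() is the only change needed to move
# the training points, parameters, and RHS evaluation onto CuArrays.
alg = NNODE(chain, Adam(0.01); device = gpu_device(),
    strategy = GridTraining(0.01f0))

sol = solve(prob, alg; verbose = false, maxiters = 2000)

As in the PATCH 3/5 tests, any `additional_loss(phi, θ)` should evaluate `phi` on the whole time vector at once (`vec(phi(t_, θ))`) rather than point by point, since the batched forward pass is what the device abstraction vectorizes.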
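For the KernelAbstractions kernel added in PATCH 1/5, here is a CPU-checkable sketch of the same launch pattern; the names `columnwise_rhs!` and `kernel!`, the toy RHS, and the sizes are illustrative and not part of the series. One work-item is launched per time point and applies the user RHS to the matching column, so the `ndrange` passed at call time must equal `size(out, 2)`; the kernel fills `du` in place, and the caller is responsible for returning or synchronizing that buffer.

using KernelAbstractions: @kernel, @Const, @index, CPU, synchronize

# One work-item per column: item i applies the RHS to column i of `out`.
@kernel function columnwise_rhs!(f, du, @Const(out), @Const(p), @Const(t))
    i = @index(Global, Linear)
    @views du[:, i] .= f(out[:, i], p, t[i])
end

# Instantiating against a backend specializes the kernel; on the GPU this
# corresponds to custom_broadcast!(CUDABackend()) from the patch.
kernel! = columnwise_rhs!(CPU())

f(u, p, t) = p .* u .+ sin(t)          # toy out-of-place RHS
out = rand(Float32, 3, 8)              # 3 states × 8 time points
t = collect(range(0.0f0, 1.0f0; length = 8))
p = Float32[2.0, 0.5, 1.0]
du = similar(out)

kernel!(f, du, out, p, t; ndrange = size(out, 2))  # one work-item per column
synchronize(CPU())

# Agrees with the serial CPU fallback, rhs(::LuxCPUDevice, ...).
@assert du ≈ reduce(hcat, [f(out[:, i], p, t[i]) for i in axes(out, 2)])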