diff --git a/src/NeuralPDE.jl b/src/NeuralPDE.jl
index a2ffc2370a..969892ec8d 100644
--- a/src/NeuralPDE.jl
+++ b/src/NeuralPDE.jl
@@ -19,8 +19,9 @@ using RuntimeGeneratedFunctions
 using Statistics
 using ArrayInterface
 import Optim
-using Symbolics: wrap, unwrap, arguments, operation
-using SymbolicUtils
+using Symbolics: wrap, unwrap, arguments, operation, symtype, @arrayop, Arr
+using SymbolicUtils.Code
+using SymbolicUtils: Prewalk, Postwalk, Chain
 using AdvancedHMC, LogDensityProblems, LinearAlgebra, Functors, MCMCChains
 using MonteCarloMeasurements: Particles
 using ModelingToolkit: value, nameof, toexpr, build_expr, expand_derivatives, Interval,
@@ -32,7 +33,10 @@ using SciMLBase: @add_kwonly, parameterless_type
 using UnPack: @unpack
 import ChainRulesCore, Lux, ComponentArrays
 using Lux: FromFluxAdaptor, recursive_eltype
-using ChainRulesCore: @non_differentiable
+using ChainRulesCore: @non_differentiable, @ignore_derivatives
+using PDEBase: AbstractVarEqMapping, VariableMap, cardinalize_eqs!, get_depvars,
+               get_indvars, differential_order
+using LuxCore: stateless_apply
 
 RuntimeGeneratedFunctions.init(@__MODULE__)
 
@@ -41,6 +45,7 @@ abstract type AbstractPINN end
 abstract type AbstractTrainingStrategy end
 
 include("pinn_types.jl")
+include("eq_data.jl")
 include("symbolic_utilities.jl")
 include("training_strategies.jl")
 include("adaptive_losses.jl")
@@ -48,6 +53,7 @@ include("ode_solve.jl")
 # include("rode_solve.jl")
 include("dae_solve.jl")
 include("transform_inf_integral.jl")
+include("loss_function_generation.jl")
 include("discretize.jl")
 include("neural_adapter.jl")
 include("advancedHMC_MCMC.jl")
diff --git a/src/discretize.jl b/src/discretize.jl
index 9a40e0fe82..6290aade16 100644
--- a/src/discretize.jl
+++ b/src/discretize.jl
@@ -1,204 +1,24 @@
 """
-Build a loss function for a PDE or a boundary condition.
+    generate_training_sets(domains, dx, eqs, bcs, eltypeθ, varmap)
 
-# Examples: System of PDEs:
-
-Take expressions in the form:
-
-[Dx(u1(x,y)) + 4*Dy(u2(x,y)) ~ 0,
- Dx(u2(x,y)) + 9*Dy(u1(x,y)) ~ 0]
-
-to
-
-:((cord, θ, phi, derivative, u)->begin
-          #= ... =#
-          #= ... =#
-          begin
-              (u1, u2) = (θ.depvar.u1, θ.depvar.u2)
-              (phi1, phi2) = (phi[1], phi[2])
-              let (x, y) = (cord[1], cord[2])
-                  [(+)(derivative(phi1, u, [x, y], [[ε, 0.0]], 1, u1), (*)(4, derivative(phi2, u, [x, y], [[0.0, ε]], 1, u1))) - 0,
-                   (+)(derivative(phi2, u, [x, y], [[ε, 0.0]], 1, u2), (*)(9, derivative(phi1, u, [x, y], [[0.0, ε]], 1, u2))) - 0]
-              end
-          end
-      end)
-
-for Lux.AbstractExplicitLayer.
-"""
-function build_symbolic_loss_function(pinnrep::PINNRepresentation, eqs;
-                                      eq_params = SciMLBase.NullParameters(),
-                                      param_estim = false,
-                                      default_p = nothing,
-                                      bc_indvars = pinnrep.indvars,
-                                      integrand = nothing,
-                                      dict_transformation_vars = nothing,
-                                      transformation_vars = nothing,
-                                      integrating_depvars = pinnrep.depvars)
-    @unpack indvars, depvars, dict_indvars, dict_depvars, dict_depvar_input,
-            phi, derivative, integral,
-            multioutput, init_params, strategy, eq_params,
-            param_estim, default_p = pinnrep
-
-    eltypeθ = eltype(pinnrep.flat_init_params)
-
-    if integrand isa Nothing
-        loss_function = parse_equation(pinnrep, eqs)
-        this_eq_pair = pair(eqs, depvars, dict_depvars, dict_depvar_input)
-        this_eq_indvars = unique(vcat(values(this_eq_pair)...))
-    else
-        this_eq_pair = Dict(map(
-            intvars -> dict_depvars[intvars] => dict_depvar_input[intvars],
-            integrating_depvars))
-        this_eq_indvars = transformation_vars isa Nothing ?
- unique(vcat(values(this_eq_pair)...)) : transformation_vars - loss_function = integrand - end - - vars = :(cord, $θ, phi, derivative, integral, u, p) - ex = Expr(:block) - if multioutput - θ_nums = Symbol[] - phi_nums = Symbol[] - for v in depvars - num = dict_depvars[v] - push!(θ_nums, :($(Symbol(:($θ), num)))) - push!(phi_nums, :($(Symbol(:phi, num)))) - end - - expr_θ = Expr[] - expr_phi = Expr[] - - acum = [0; accumulate(+, map(length, init_params))] - sep = [(acum[i] + 1):acum[i + 1] for i in 1:(length(acum) - 1)] - - for i in eachindex(depvars) - push!(expr_θ, :($θ.depvar.$(depvars[i]))) - push!(expr_phi, :(phi[$i])) - end - - vars_θ = Expr(:(=), build_expr(:tuple, θ_nums), build_expr(:tuple, expr_θ)) - push!(ex.args, vars_θ) - - vars_phi = Expr(:(=), build_expr(:tuple, phi_nums), build_expr(:tuple, expr_phi)) - push!(ex.args, vars_phi) - end - - #Add an expression for parameter symbols - if param_estim == true && eq_params != SciMLBase.NullParameters() - params_symbols = Symbol[] - expr_params = Expr[] - for (i, eq_param) in enumerate(eq_params) - push!(expr_params, :($θ.p[$((i):(i))])) - push!(params_symbols, Symbol(:($eq_param))) - end - params_eq = Expr(:(=), build_expr(:tuple, params_symbols), - build_expr(:tuple, expr_params)) - push!(ex.args, params_eq) - end - - if eq_params != SciMLBase.NullParameters() && param_estim == false - params_symbols = Symbol[] - expr_params = Expr[] - for (i, eq_param) in enumerate(eq_params) - push!(expr_params, :(ArrayInterface.allowed_getindex(p, ($i):($i)))) - push!(params_symbols, Symbol(:($eq_param))) - end - params_eq = Expr(:(=), build_expr(:tuple, params_symbols), - build_expr(:tuple, expr_params)) - push!(ex.args, params_eq) - end - - eq_pair_expr = Expr[] - for i in keys(this_eq_pair) - push!(eq_pair_expr, :($(Symbol(:cord, :($i))) = vcat($(this_eq_pair[i]...)))) - end - vcat_expr = Expr(:block, :($(eq_pair_expr...))) - vcat_expr_loss_functions = Expr(:block, vcat_expr, loss_function) # TODO rename - - if strategy isa QuadratureTraining - indvars_ex = get_indvars_ex(bc_indvars) - left_arg_pairs, right_arg_pairs = this_eq_indvars, indvars_ex - vars_eq = Expr(:(=), build_expr(:tuple, left_arg_pairs), - build_expr(:tuple, right_arg_pairs)) - else - indvars_ex = [:($:cord[[$i], :]) for (i, x) in enumerate(this_eq_indvars)] - left_arg_pairs, right_arg_pairs = this_eq_indvars, indvars_ex - vars_eq = Expr(:(=), build_expr(:tuple, left_arg_pairs), - build_expr(:tuple, right_arg_pairs)) - end - - if !(dict_transformation_vars isa Nothing) - transformation_expr_ = Expr[] - for (i, u) in dict_transformation_vars - push!(transformation_expr_, :($i = $u)) - end - transformation_expr = Expr(:block, :($(transformation_expr_...))) - vcat_expr_loss_functions = Expr(:block, transformation_expr, vcat_expr, - loss_function) - end - let_ex = Expr(:let, vars_eq, vcat_expr_loss_functions) - push!(ex.args, let_ex) - expr_loss_function = :(($vars) -> begin - $ex - end) -end - -""" - build_loss_function(eqs, indvars, depvars, phi, derivative, init_params; bc_indvars=nothing) - -Returns the body of loss function, which is the executable Julia function, for the main -equation or boundary condition. -""" -function build_loss_function(pinnrep::PINNRepresentation, eqs, bc_indvars) - @unpack eq_params, param_estim, default_p, phi, derivative, integral = pinnrep - - bc_indvars = bc_indvars === nothing ? 
pinnrep.indvars : bc_indvars
-
-    expr_loss_function = build_symbolic_loss_function(pinnrep, eqs;
-                                                      bc_indvars = bc_indvars,
-                                                      eq_params = eq_params,
-                                                      param_estim = param_estim,
-                                                      default_p = default_p)
-    u = get_u()
-    _loss_function = @RuntimeGeneratedFunction(expr_loss_function)
-    loss_function = (cord, θ) -> begin
-        _loss_function(cord, θ, phi, derivative, integral, u,
-                       default_p)
-    end
-    return loss_function
-end
-
-"""
-    generate_training_sets(domains,dx,bcs,_indvars::Array,_depvars::Array)
-
-Returns training sets for equations and boundary condition, that is used for GridTraining
+Returns training sets for the equations and boundary conditions, used by the `GridTraining`
 strategy.
 """
 function generate_training_sets end
 
-function generate_training_sets(domains, dx, eqs, bcs, eltypeθ, _indvars::Array,
-                                _depvars::Array)
-    depvars, indvars, dict_indvars, dict_depvars, dict_depvar_input = get_vars(_indvars,
-                                                                               _depvars)
-    return generate_training_sets(domains, dx, eqs, bcs, eltypeθ, dict_indvars,
-                                  dict_depvars)
-end
-
 # Generate training set in the domain and on the boundary
-function generate_training_sets(domains, dx, eqs, bcs, eltypeθ, dict_indvars::Dict,
-                                dict_depvars::Dict)
+function generate_training_sets(domains, dx, eqs, bcs, eltypeθ, varmap)
     if dx isa Array
         dxs = dx
     else
        dxs = fill(dx, length(domains))
     end
-    spans = [infimum(d.domain):dx:supremum(d.domain) for (d, dx) in zip(domains, dxs)]
-    dict_var_span = Dict([Symbol(d.variables) => infimum(d.domain):dx:supremum(d.domain)
+    dict_var_span = Dict([d.variables => infimum(d.domain):dx:supremum(d.domain)
                           for (d, dx) in zip(domains, dxs)])
-    bound_args = get_argument(bcs, dict_indvars, dict_depvars)
-    bound_vars = get_variables(bcs, dict_indvars, dict_depvars)
+    bound_args = get_argument(bcs, varmap)
+    bound_vars = get_variables(bcs, varmap)
 
     dif = [eltypeθ[] for i in 1:size(domains)[1]]
     for _args in bound_vars
@@ -213,7 +33,7 @@ function generate_training_sets(domains, dx, eqs, bcs, eltypeθ, dict_indvars::D
         setdiff(c, d)
     end
 
-    dict_var_span_ = Dict([Symbol(d.variables) => bc for (d, bc) in zip(domains, bc_data)])
+    dict_var_span_ = Dict([d.variables => bc for (d, bc) in zip(domains, bc_data)])
 
     bcs_train_sets = map(bound_args) do bt
         span = map(b -> get(dict_var_span, b, b), bt)
@@ -221,8 +41,8 @@ function generate_training_sets(domains, dx, eqs, bcs, eltypeθ, dict_indvars::D
                      hcat(vec(map(points -> collect(points), Iterators.product(span...)))...))
     end
 
-    pde_vars = get_variables(eqs, dict_indvars, dict_depvars)
-    pde_args = get_argument(eqs, dict_indvars, dict_depvars)
+    pde_vars = get_variables(eqs, varmap)
+    pde_args = get_argument(eqs, varmap)
 
     pde_train_set = adapt(eltypeθ,
                           hcat(vec(map(points -> collect(points),
@@ -244,25 +64,11 @@ training strategy: StochasticTraining, QuasiRandomTraining, QuadratureTraining.
""" function get_bounds end -function get_bounds(domains, eqs, bcs, eltypeθ, _indvars::Array, _depvars::Array, strategy) - depvars, indvars, dict_indvars, dict_depvars, dict_depvar_input = get_vars(_indvars, - _depvars) - return get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, strategy) -end - -function get_bounds(domains, eqs, bcs, eltypeθ, _indvars::Array, _depvars::Array, - strategy::QuadratureTraining) - depvars, indvars, dict_indvars, dict_depvars, dict_depvar_input = get_vars(_indvars, - _depvars) - return get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, strategy) -end - -function get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, - strategy::QuadratureTraining) - dict_lower_bound = Dict([Symbol(d.variables) => infimum(d.domain) for d in domains]) - dict_upper_bound = Dict([Symbol(d.variables) => supremum(d.domain) for d in domains]) - - pde_args = get_argument(eqs, dict_indvars, dict_depvars) +function get_bounds( + domains, eqs, bcs, eltypeθ, v::VariableMap, strategy::QuadratureTraining) + dict_lower_bound = Dict([d.variables => infimum(d.domain) for d in domains]) + dict_upper_bound = Dict([d.variables => supremum(d.domain) for d in domains]) + pde_args = get_argument(eqs, v) pde_lower_bounds = map(pde_args) do pd span = map(p -> get(dict_lower_bound, p, p), pd) @@ -274,7 +80,7 @@ function get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, end pde_bounds = [pde_lower_bounds, pde_upper_bounds] - bound_vars = get_variables(bcs, dict_indvars, dict_depvars) + bound_vars = get_variables(bcs, v) bcs_lower_bounds = map(bound_vars) do bt map(b -> dict_lower_bound[b], bt) @@ -283,26 +89,25 @@ function get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, map(b -> dict_upper_bound[b], bt) end bcs_bounds = [bcs_lower_bounds, bcs_upper_bounds] - [pde_bounds, bcs_bounds] end -function get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, strategy) +function get_bounds(domains, eqs, bcs, eltypeθ, v::VariableMap, strategy) dx = 1 / strategy.points - dict_span = Dict([Symbol(d.variables) => [ + dict_span = Dict([d.variables => [ infimum(d.domain) + dx, supremum(d.domain) - dx ] for d in domains]) # pde_bounds = [[infimum(d.domain),supremum(d.domain)] for d in domains] - pde_args = get_argument(eqs, dict_indvars, dict_depvars) + pde_args = get_argument(eqs, v) pde_bounds = map(pde_args) do pde_arg bds = mapreduce(s -> get(dict_span, s, fill(s, 2)), hcat, pde_arg) bds = eltypeθ.(bds) bds[1, :], bds[2, :] end - bound_args = get_argument(bcs, dict_indvars, dict_depvars) + bound_args = get_argument(bcs, v) bcs_bounds = map(bound_args) do bound_arg bds = mapreduce(s -> get(dict_span, s, fill(s, 2)), hcat, bound_arg) bds = eltypeθ.(bds) @@ -311,18 +116,18 @@ function get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, str return pde_bounds, bcs_bounds end +# TODO: Get this to work with varmap function get_numeric_integral(pinnrep::PINNRepresentation) - @unpack strategy, indvars, depvars, multioutput, derivative, - depvars, indvars, dict_indvars, dict_depvars = pinnrep + @unpack strategy, multioutput, derivative, varmap = pinnrep - integral = (u, cord, phi, integrating_var_id, integrand_func, lb, ub, θ; strategy = strategy, indvars = indvars, depvars = depvars, dict_indvars = dict_indvars, dict_depvars = dict_depvars) -> begin + integral = (u, cord, phi, integrating_var_id, integrand_func, lb, ub, θ; strategy = strategy, varmap = varmap) -> begin function integration_(cord, lb, ub, θ) cord_ = cord 
function integrand_(x, p) ChainRulesCore.@ignore_derivatives @views(cord_[integrating_var_id]) .= x return integrand_func(cord_, p, phi, derivative, nothing, u, nothing) end - prob_ = IntegralProblem(integrand_, (lb, ub), θ) + prob_ = IntegralProblem(integrand_, lb, ub, θ) sol = solve(prob_, CubatureJLh(), reltol = 1e-3, abstol = 1e-3)[1] return sol @@ -357,6 +162,10 @@ function get_numeric_integral(pinnrep::PINNRepresentation) end end +function lazyconvert(T, x::Symbolics.Arr) + Symbolics.array_term(convert, T, x, size = size(x)) +end + """ prob = symbolic_discretize(pde_system::PDESystem, discretization::AbstractPINN) @@ -369,30 +178,35 @@ which is later optimized upon to give Solution or the Solution Distribution of t For more information, see `discretize` and `PINNRepresentation`. """ -function SciMLBase.symbolic_discretize(pde_system::PDESystem, - discretization::AbstractPINN) - eqs = pde_system.eqs - bcs = pde_system.bcs - chain = discretization.chain - - domains = pde_system.domain - eq_params = pde_system.ps - defaults = pde_system.defaults +function SciMLBase.symbolic_discretize(pdesys::PDESystem, + discretization::PhysicsInformedNN) + cardinalize_eqs!(pdesys) + eqs = pdesys.eqs + bcs = pdesys.bcs + domains = pdesys.domain + eq_params = pdesys.ps + defaults = pdesys.defaults default_p = eq_params == SciMLBase.NullParameters() ? nothing : [defaults[ep] for ep in eq_params] + chain = discretization.chain param_estim = discretization.param_estim additional_loss = discretization.additional_loss adaloss = discretization.adaptive_loss - - depvars, indvars, dict_indvars, dict_depvars, dict_depvar_input = get_vars( - pde_system.indvars, - pde_system.depvars) - multioutput = discretization.multioutput init_params = discretization.init_params + phi = discretization.phi + derivative = discretization.derivative + strategy = discretization.strategy + logger = discretization.logger + log_frequency = discretization.log_options.log_frequency + iteration = discretization.iteration + self_increment = discretization.self_increment + + varmap = VariableMap(pdesys, discretization) + eqdata = EquationData(pdesys, varmap, strategy) - if init_params === nothing + if isnothing(init_params) # Use the initialization of the neural network framework # But for Lux, default to Float64 # This is done because Float64 is almost always better for these applications @@ -403,8 +217,10 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem, x)) Float64.(_x) # No ComponentArray GPU support end - names = ntuple(i -> depvars[i], length(chain)) - init_params = ComponentArrays.ComponentArray(NamedTuple{names}(i + # chain_names = ntuple(i -> depvars(eqs[i].lhs, eqdata), length(chain)) + # @show chain_names + chain_names = Tuple(Symbol.(operation.(unwrap.(pdesys.dvs)))) + init_params = ComponentArrays.ComponentArray(NamedTuple{chain_names}(i for i in x)) else init_params = Float64.(ComponentArrays.ComponentArray(Lux.initialparameters( @@ -415,112 +231,112 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem, init_params = init_params end + if phi isa AbstractVector + chain_params_symbols = map(chain_names) do chain_name + _params = getproperty(init_params, chain_name) + [ + first(@parameters Symbol("pss_" * string(chain_name))[1:length(_params)]), + first(@parameters Symbol("T_" * string(chain_name))::typeof(typeof(_params))=typeof(_params) [tunable = false]) + ] + end + outs = [] + for i in eachindex(phi) + out = x -> stateless_apply(phi[i].f, x, + lazyconvert(chain_params_symbols[i][2], 
chain_params_symbols[i][1]))[1] + push!(outs, out) + end + else + chain_params_symbols = [ + first(@parameters pss[1:length(init_params)]), + first(@parameters T::typeof(typeof(init_params))=typeof(init_params) [tunable = false]) + ] + outs = [] + for i in eachindex(pdesys.dvs) + out = x -> stateless_apply( + phi.f, x, lazyconvert(chain_params_symbols[2], chain_params_symbols[1]))[i] + push!(outs, out) + end + end + + depvars_outs_map = Dict( + operation.(unwrap.(pdesys.dvs)) .=> outs + ) + flat_init_params = if init_params isa ComponentArrays.ComponentArray init_params - elseif multioutput - @assert length(init_params) == length(depvars) - names = ntuple(i -> depvars[i], length(init_params)) - x = ComponentArrays.ComponentArray(NamedTuple{names}(i for i in init_params)) + # elseif multioutput + # # @assert length(init_params) == length(depvars) + # names = ntuple(i -> depvars(eqs[i].lhs, eqdata), length(init_params)) + # x = ComponentArrays.ComponentArray(NamedTuple{names}(i for i in init_params)) else ComponentArrays.ComponentArray(init_params) end - flat_init_params = if param_estim == false && multioutput + flat_init_params = if param_estim == false ComponentArrays.ComponentArray(; depvar = flat_init_params) - elseif param_estim == false && !multioutput - flat_init_params else ComponentArrays.ComponentArray(; depvar = flat_init_params, p = default_p) end - eltypeθ = eltype(flat_init_params) - - if adaloss === nothing - adaloss = NonAdaptiveLoss{eltypeθ}() - end - - phi = discretization.phi - - if (phi isa Vector && phi[1].f isa Lux.AbstractExplicitLayer) + if phi isa Vector for ϕ in phi ϕ.st = adapt(parameterless_type(ComponentArrays.getdata(flat_init_params)), ϕ.st) end - elseif (!(phi isa Vector) && phi.f isa Lux.AbstractExplicitLayer) + else phi.st = adapt(parameterless_type(ComponentArrays.getdata(flat_init_params)), phi.st) end - derivative = discretization.derivative - strategy = discretization.strategy - - logger = discretization.logger - log_frequency = discretization.log_options.log_frequency - iteration = discretization.iteration - self_increment = discretization.self_increment - - if !(eqs isa Array) - eqs = [eqs] - end + # if multioutput + # # acum = [0; accumulate(+, map(length, init_params))] + # phi = map(enumerate(pdesys.dvs)) do (i, dv) + # (coord, expr_θ) -> phi[i](coord, expr_θ.depvar.$(dv)) + # end + # else + # # phimap = nothing + # phi = (coord, expr_θ) -> phi(coord, expr_θ.depvar) + # end - pde_indvars = if strategy isa QuadratureTraining - get_argument(eqs, dict_indvars, dict_depvars) - else - get_variables(eqs, dict_indvars, dict_depvars) - end + eltypeθ = eltype(flat_init_params) - bc_indvars = if strategy isa QuadratureTraining - get_argument(bcs, dict_indvars, dict_depvars) - else - get_variables(bcs, dict_indvars, dict_depvars) + if adaloss === nothing + adaloss = NonAdaptiveLoss{eltypeθ}() end - pde_integration_vars = get_integration_variables(eqs, dict_indvars, dict_depvars) - bc_integration_vars = get_integration_variables(bcs, dict_indvars, dict_depvars) + eqs = map(eq -> eq.lhs, eqs) + bcs = map(bc -> bc.lhs, bcs) pinnrep = PINNRepresentation(eqs, bcs, domains, eq_params, defaults, default_p, - param_estim, additional_loss, adaloss, depvars, indvars, - dict_indvars, dict_depvars, dict_depvar_input, logger, + param_estim, additional_loss, adaloss, varmap, logger, multioutput, iteration, init_params, flat_init_params, phi, - derivative, - strategy, pde_indvars, bc_indvars, pde_integration_vars, - bc_integration_vars, nothing, nothing, nothing, nothing) + 
derivative, depvars_outs_map, + strategy, eqdata, nothing, nothing, nothing, nothing) - integral = get_numeric_integral(pinnrep) + #integral = get_numeric_integral(pinnrep) - symbolic_pde_loss_functions = [build_symbolic_loss_function(pinnrep, eq; - bc_indvars = pde_indvar) - for (eq, pde_indvar) in zip(eqs, pde_indvars, - pde_integration_vars)] + #symbolic_pde_loss_functions = [build_symbolic_loss_function(pinnrep, eq) for eq in eqs] - symbolic_bc_loss_functions = [build_symbolic_loss_function(pinnrep, bc; - bc_indvars = bc_indvar) - for (bc, bc_indvar) in zip(bcs, bc_indvars, - bc_integration_vars)] + #symbolic_bc_loss_functions = [build_symbolic_loss_function(pinnrep, bc) |> toexpr for bc in bcs] - pinnrep.integral = integral - pinnrep.symbolic_pde_loss_functions = symbolic_pde_loss_functions - pinnrep.symbolic_bc_loss_functions = symbolic_bc_loss_functions + #pinnrep.integral = integral + #pinnrep.symbolic_pde_loss_functions = symbolic_pde_loss_functions + #pinnrep.symbolic_bc_loss_functions = symbolic_bc_loss_functions - datafree_pde_loss_functions = [build_loss_function(pinnrep, eq, pde_indvar) - for (eq, pde_indvar, integration_indvar) in zip(eqs, - pde_indvars, - pde_integration_vars)] - - datafree_bc_loss_functions = [build_loss_function(pinnrep, bc, bc_indvar) - for (bc, bc_indvar, integration_indvar) in zip(bcs, - bc_indvars, - bc_integration_vars)] + datafree_pde_loss_functions = [build_loss_function(pinnrep, eq) for eq in eqs] + datafree_bc_loss_functions = [build_loss_function(pinnrep, bc) for bc in bcs] pde_loss_functions, bc_loss_functions = merge_strategy_with_loss_function(pinnrep, strategy, datafree_pde_loss_functions, datafree_bc_loss_functions) + # setup for all adaptive losses num_pde_losses = length(pde_loss_functions) num_bc_losses = length(bc_loss_functions) + # assume one single additional loss function if there is one. this means that the user needs to lump all their functions into a single one, - num_additional_loss = additional_loss isa Nothing ? 0 : 1 + num_additional_loss = isnothing(additional_loss) ? 0 : 1 adaloss_T = eltype(adaloss.pde_loss_weights) @@ -706,14 +522,14 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem, end """ - prob = discretize(pde_system::PDESystem, discretization::PhysicsInformedNN) + discretize(pdesys::PDESystem, discretization::PhysicsInformedNN) Transforms a symbolic description of a ModelingToolkit-defined `PDESystem` and generates an `OptimizationProblem` for [Optimization.jl](https://docs.sciml.ai/Optimization/stable/) whose solution is the solution to the PDE. 
""" -function SciMLBase.discretize(pde_system::PDESystem, discretization::PhysicsInformedNN) - pinnrep = symbolic_discretize(pde_system, discretization) +function SciMLBase.discretize(pdesys::PDESystem, discretization::PhysicsInformedNN) + pinnrep = symbolic_discretize(pdesys, discretization) f = OptimizationFunction(pinnrep.loss_functions.full_loss_function, Optimization.AutoZygote()) Optimization.OptimizationProblem(f, pinnrep.flat_init_params) diff --git a/src/eq_data.jl b/src/eq_data.jl new file mode 100644 index 0000000000..4d059fbea2 --- /dev/null +++ b/src/eq_data.jl @@ -0,0 +1,97 @@ +struct EquationData <: AbstractVarEqMapping + depvarmap::Any + indvarmap::Any + args::Any + ivargs::Any + argmap::Any +end + +function EquationData(pdesys, varmap, strategy) + eqs = map(eq -> eq.lhs, pdesys.eqs) + bcs = map(eq -> eq.lhs, pdesys.bcs) + alleqs = vcat(eqs, bcs) + + argmap = map(alleqs) do eq + eq => get_argument([eq], varmap)[1] + end |> Dict + + depvarmap = map(alleqs) do eq + eq => get_depvars(eq, varmap.depvar_ops) + end |> Dict + + indvarmap = map(alleqs) do eq + eq => get_indvars(eq, varmap) + end |> Dict + + # Why? + if strategy isa QuadratureTraining + _args = get_argument(alleqs, varmap) + else + _args = get_variables(alleqs, varmap) + end + + args = map(zip(alleqs, _args)) do (eq, args) + eq => args + end |> Dict + + ivargs = get_iv_argument(alleqs, varmap) + + ivargs = map(zip(alleqs, ivargs)) do (eq, args) + eq => args + end |> Dict + + EquationData(depvarmap, indvarmap, args, ivargs, argmap) +end + +function depvars(eq, eqdata::EquationData) + eqdata.depvarmap[eq] +end + +function indvars(eq, eqdata::EquationData) + eqdata.indvarmap[eq] +end + +function eq_args(eq, eqdata::EquationData) + eqdata.args[eq] +end + +function eq_iv_args(eq, eqdata::EquationData) + eqdata.ivargs[eq] +end + +argument(eq, eqdata) = eqdata.argmap[eq] + +function get_iv_argument(eqs, v::VariableMap) + vars = map(eqs) do eq + _vars = map(depvar -> get_depvars(eq, [depvar]), v.depvar_ops) + f_vars = filter(x -> !isempty(x), _vars) + mapreduce(vars -> mapreduce(op -> v.args[op], vcat, operation.(vars), init = []), + vcat, f_vars, init = []) + end + args_ = map(vars) do _vars + seen = [] + filter(_vars) do x + if x isa Number + error("Unreachable") + else + if any(isequal(x), seen) + false + else + push!(seen, x) + true + end + end + end + end + return args_ +end + +""" + get_iv_variables(eqs, v::VariableMap) + +Returns all variables that are used in each equations or boundary condition. +""" +function get_iv_variables(eqs, v::VariableMap) + args = get_iv_argument(eqs, v) + return map(arg -> filter(x -> !(x isa Number), arg), args) +end diff --git a/src/loss_function_generation.jl b/src/loss_function_generation.jl new file mode 100644 index 0000000000..ae7d7e4ced --- /dev/null +++ b/src/loss_function_generation.jl @@ -0,0 +1,182 @@ +# TODO: add multioutput +# TODO: add integrals + +function build_symbolic_loss_function(pinnrep::PINNRepresentation, eq; + eq_params = SciMLBase.NullParameters(), + param_estim = false, + default_p = [], + integrand = nothing, + transformation_vars = nothing) + @unpack varmap, eqdata, + phi, derivative, integral, + multioutput, init_params, strategy, + param_estim, default_p = pinnrep + + eltypeθ = eltype(pinnrep.flat_init_params) + + eq = eq isa Equation ? 
+
+    eq_args = get(eqdata.ivargs, eq, varmap.x̄)
+
+    if isnothing(integrand)
+        this_eq_indvars = indvars(eq, eqdata)
+        this_eq_depvars = depvars(eq, eqdata)
+        loss_function = parse_equation(pinnrep, eq, eq_iv_args(eq, eqdata))
+    else
+        this_eq_indvars = transformation_vars isa Nothing ?
+                          unique(indvars(eq, eqdata)) : transformation_vars
+        loss_function = integrand
+    end
+
+    n = length(this_eq_indvars)
+
+    get_ps = if param_estim == true && !isnothing(default_p)
+        (θ) -> θ.p[1:length(eq_params)]
+    else
+        (θ) -> default_p
+    end
+
+    # Numbers in the argument list become constant coordinate rows; symbolic
+    # arguments pick out the matching row of the training coordinates.
+    function get_coords(cord)
+        map(enumerate(eq_args)) do (i, x)
+            if x isa Number
+                fill(convert(eltypeθ, x), size(cord[[1], :]))
+            else
+                cord[[i], :]
+            end
+        end
+    end
+
+    full_loss_func = (cord, θ, phi, p) -> begin
+        coords = [[nothing]]
+        @ignore_derivatives coords = get_coords(cord)
+        loss_function(coords, θ, phi, get_ps(θ))
+    end
+    return full_loss_func
+end
+
+@register_array_symbolic (f::Phi{<:Lux.AbstractExplicitLayer})(
+    x::AbstractArray, ps::Union{NamedTuple, <:AbstractVector}) begin
+    size = LuxCore.outputsize(f.f, x, LuxCore._default_rng())
+    eltype = Real
+end
+
+function build_loss_function(pinnrep, eq)
+    @unpack eq_params, param_estim, default_p, phi, multioutput, derivative, integral = pinnrep
+    _loss_function = build_symbolic_loss_function(pinnrep, eq,
+        eq_params = eq_params,
+        param_estim = param_estim)
+    loss_function = (cord, θ) -> begin
+        _loss_function(cord, θ, phi,
+            default_p)
+    end
+    return loss_function
+end
+
+function operations(ex)
+    if istree(ex)
+        op = operation(ex)
+        return vcat(operations.(arguments(ex))..., op)
+    end
+    return []
+end
+
+############################################################################################
+# Parse equation
+############################################################################################
+
+function parse_equation(pinnrep::PINNRepresentation, term, ivs; is_integral = false,
+        dict_transformation_vars = nothing,
+        transformation_vars = nothing)
+    @unpack varmap, eqdata, derivative, integral, flat_init_params, phi, depvars_outs_map = pinnrep
+    eltypeθ = eltype(flat_init_params)
+
+    ex_vars = get_depvars(term, varmap.depvar_ops)
+
+    # if multioutput
+    #     dummyvars = @variables switch
+    # else
+    #     dummyvars = @variables switch
+    # end
+    dummyvars = @variables switch
+
+    dummyvars = unwrap.(dummyvars)
+    deriv_rules = generate_derivative_rules(
+        term, eqdata, eltypeθ, dummyvars, derivative, varmap, depvars_outs_map)
+    ch = Prewalk(Chain(deriv_rules))
+
+    expr = ch(term)
+    #expr = swch(expr)
+
+    sym_coords = DestructuredArgs(ivs)
+    ps = DestructuredArgs(varmap.ps)
+
+    args = [sym_coords, ps]
+
+    ex = Func(args, [], expr) |> toexpr |> _dot_
+
+    f = @RuntimeGeneratedFunction ex
+    return f
+end
+
+function generate_derivative_rules(
+        term, eqdata, eltypeθ, dummyvars, derivative, varmap, depvars_outs_map)
+    switch = dummyvars
+    # if symtype(phi) isa AbstractArray
+    #     phi = collect(phi)
+    # end
+
+    dvs = get_depvars(term, varmap.depvar_ops)
+
+    # Orthodox derivatives
+    n(w) = length(arguments(w))
+    rs = reduce(vcat,
+        [reduce(vcat,
+             [[@rule $((Differential(x)^d)(w)) => derivative(
+                   depvars_outs_map[operation(w)], arguments(w),
+                   get_ε(n(w), j, eltypeθ, d),
+                   d, θ)
+               for d in differential_order(term, x)]
+              for (j, x) in enumerate(varmap.args[operation(w)])],
+             init = [])
+         for w in dvs],
+        init = [])
+
+    # Mixed derivatives
+    mx = mapreduce(vcat, dvs, init = []) do w
+        mapreduce(vcat, enumerate(varmap.args[operation(w)]), init = []) do (j, x)
+            mapreduce(vcat, enumerate(varmap.args[operation(w)]), init = []) do (k, y)
+                if isequal(x, y)
+                    [(_) -> nothing]
+                else
+                    ε1 = get_ε(n(w), j, eltypeθ, 1)
+                    ε2 = get_ε(n(w), k, eltypeθ, 1)
+                    [@rule $((Differential(x))((Differential(y))(w))) => derivative(
+                         (coord_, θ_) -> derivative(
+                             depvars_outs_map[operation(w)], arguments(w),
+                             ε2, 1, θ_),
+                         arguments(w), ε1, 1, θ)]
+                end
+            end
+        end
+    end
+
+    vr = mapreduce(vcat, dvs, init = []) do w
+        @rule w => depvars_outs_map[operation(w)](arguments(w))
+    end
+
+    return [mx; rs; vr]
+end
+
+function generate_integral_rules(eq, eqdata, dummyvars)
+    phi, u, θ = dummyvars
+    #! all that should be needed is to solve an integral problem, the trick is doing this
+    #! with rules without putting symbols through the solve
+end
diff --git a/src/pinn_types.jl b/src/pinn_types.jl
index 59480d8a60..c2314b6779 100644
--- a/src/pinn_types.jl
+++ b/src/pinn_types.jl
@@ -357,23 +357,7 @@ mutable struct PINNRepresentation
     """
     The dependent variables of the system
     """
-    depvars::Any
-    """
-    The independent variables of the system
-    """
-    indvars::Any
-    """
-    A dictionary form of the independent variables. Define the structure ???
-    """
-    dict_indvars::Any
-    """
-    A dictionary form of the dependent variables. Define the structure ???
-    """
-    dict_depvars::Any
-    """
-    ???
-    """
-    dict_depvar_input::Any
+    varmap::Any
     """
     The logger as provided by the user
     """
@@ -412,25 +396,17 @@ mutable struct PINNRepresentation
     """
     derivative::Any
     """
-    The training strategy as provided by the user
-    """
-    strategy::AbstractTrainingStrategy
-    """
-    ???
-    """
-    pde_indvars::Any
+    A map from each dependent variable to its neural network output function
     """
-    ???
-    """
-    bc_indvars::Any
+    depvars_outs_map::Any
     """
-    ???
+    The training strategy as provided by the user
     """
-    pde_integration_vars::Any
+    strategy::AbstractTrainingStrategy
     """
     ???
     """
-    bc_integration_vars::Any
+    eqdata::Any
     """
     ???
""" @@ -521,39 +497,76 @@ function get_u() end # the method to calculate the derivative -function numeric_derivative(phi, u, x, εs, order, θ) +function numeric_derivative(phi, x, ε, order, θ) _type = parameterless_type(ComponentArrays.getdata(θ)) - ε = εs[order] - _epsilon = inv(first(ε[ε .!= zero(ε)])) - + _epsilon = inv(first(ε[ε .!= zero(eltype(ε))])) ε = adapt(_type, ε) x = adapt(_type, x) - # any(x->x!=εs[1],εs) - # εs is the epsilon for each order, if they are all the same then we use a fancy formula - # if order 1, this is trivially true - - if order > 4 || any(x -> x != εs[1], εs) - return (numeric_derivative(phi, u, x .+ ε, @view(εs[1:(end - 1)]), order - 1, θ) - .- - numeric_derivative(phi, u, x .- ε, @view(εs[1:(end - 1)]), order - 1, θ)) .* - _epsilon ./ 2 - elseif order == 4 - return (u(x .+ 2 .* ε, θ, phi) .- 4 .* u(x .+ ε, θ, phi) + if order == 4 + return (phi(x .+ 2 .* ε, θ) .- 4 .* phi(x .+ ε, θ) .+ - 6 .* u(x, θ, phi) + 6 .* phi(x, θ) .- - 4 .* u(x .- ε, θ, phi) .+ u(x .- 2 .* ε, θ, phi)) .* _epsilon^4 + 4 .* phi(x .- ε, θ) .+ phi(x .- 2 .* ε, θ)) .* _epsilon^4 elseif order == 3 - return (u(x .+ 2 .* ε, θ, phi) .- 2 .* u(x .+ ε, θ, phi) .+ 2 .* u(x .- ε, θ, phi) + return (phi(x .+ 2 .* ε, θ) .- 2 .* phi(x .+ ε, θ) .+ 2 .* phi(x .- ε, θ) - - u(x .- 2 .* ε, θ, phi)) .* _epsilon^3 ./ 2 + phi(x .- 2 .* ε, θ)) .* _epsilon^3 ./ 2 elseif order == 2 - return (u(x .+ ε, θ, phi) .+ u(x .- ε, θ, phi) .- 2 .* u(x, θ, phi)) .* _epsilon^2 + return (phi(x .+ ε, θ) .+ phi(x .- ε, θ) .- 2 .* phi(x, θ)) .* _epsilon^2 elseif order == 1 - return (u(x .+ ε, θ, phi) .- u(x .- ε, θ, phi)) .* _epsilon ./ 2 + return (phi(x .+ ε, θ) .- phi(x .- ε, θ)) .* _epsilon ./ 2 + else + error("This shouldn't happen! Got an order of $(order).") + end +end + +#@register_symbolic(numeric_derivative(phi, x, ε, order, θ)) + +function ufunc(u, phi, v) + if symtype(phi) isa AbstractArray + return phi[findfirst(w -> isequal(operation(w), operation(u)), v.ū)] else - error("This shouldn't happen!") + return phi end end + +#= +_vcat(x::Number...) = vcat(x...) +_vcat(x::AbstractArray{<:Number}...) = vcat(x...) +function _vcat(x::Union{Number, AbstractArray{<:Number}}...) + example = first(Iterators.filter(e -> !(e isa Number), x)) + dims = (1, size(example)[2:end]...) + x = map(el -> el isa Number ? (typeof(example))(fill(el, dims)) : el, x) + _vcat(x...) +end +_vcat(x...) = vcat(x...) 
+https://github.com/SciML/NeuralPDE.jl/pull/627/files +=# + +function reducevcat(vector::Vector, eltypeθ) + isnothing(vector) && return [[nothing]] + if all(x -> x isa Number, vector) + return vector + else + z = findfirst(x -> !(x isa Number), vector) + return rvcat(vector, vector[z], eltypeθ) + end +end + +function rvcat(example, sym, eltypeθ) + out = map(example) do x + if x isa Number + out = convert(eltypeθ, x) + out + else + out = x + out + end + end + #out = @arrayop (i,) out[i] i in 1:length(out) + + return out +end diff --git a/src/symbolic_utilities.jl b/src/symbolic_utilities.jl index c78ddeff83..31caa33bd2 100644 --- a/src/symbolic_utilities.jl +++ b/src/symbolic_utilities.jl @@ -18,11 +18,14 @@ julia> _dot_(e) """ dottable_(x) = Broadcast.dottable(x) dottable_(x::Function) = true +dottable_(x::typeof(numeric_derivative)) = false +dottable_(x::Phi) = false _dot_(x) = x function _dot_(x::Expr) dotargs = Base.mapany(_dot_, x.args) - if x.head === :call && dottable_(x.args[1]) + nodot = [:phi, Symbol("NeuralPDE.numeric_derivative"), NeuralPDE.rvcat] + if x.head === :call && dottable_(x.args[1]) && all(s -> x.args[1] != s, nodot) Expr(:., dotargs[1], Expr(:tuple, dotargs[2:end]...)) elseif x.head === :comparison Expr(:comparison, @@ -34,7 +37,9 @@ function _dot_(x::Expr) Expr(:let, undot(dotargs[1]), dotargs[2]) elseif x.head === :for # don't add dots to for x=... assignments Expr(:for, undot(dotargs[1]), dotargs[2]) - elseif (x.head === :(=) || x.head === :function || x.head === :macro) && + elseif x.head === :(=) # don't add dots to x=... assignments + Expr(:(=), dotargs[1], dotargs[2]) + elseif (x.head === :function || x.head === :macro) && Meta.isexpr(x.args[1], :call) # function or macro definition Expr(x.head, x.args[1], dotargs[2]) elseif x.head === :(<:) || x.head === :(>:) @@ -49,7 +54,6 @@ function _dot_(x::Expr) end end end - """ Create dictionary: variable => unique number for variable @@ -114,168 +118,6 @@ where - order - order of derivative. - θ - weights in neural network. """ -function _transform_expression(pinnrep::PINNRepresentation, ex; is_integral = false, - dict_transformation_vars = nothing, - transformation_vars = nothing) - @unpack indvars, depvars, dict_indvars, dict_depvars, - dict_depvar_input, multioutput, strategy, phi, - derivative, integral, flat_init_params, init_params = pinnrep - eltypeθ = eltype(flat_init_params) - - _args = ex.args - for (i, e) in enumerate(_args) - if !(e isa Expr) - if e in keys(dict_depvars) - depvar = _args[1] - num_depvar = dict_depvars[depvar] - indvars = _args[2:end] - var_ = is_integral ? :(u) : :($(Expr(:$, :u))) - ex.args = if !multioutput - [var_, Symbol(:cord, num_depvar), :($θ), :phi] - else - [ - var_, - Symbol(:cord, num_depvar), - Symbol(:($θ), num_depvar), - Symbol(:phi, num_depvar) - ] - end - break - elseif e isa ModelingToolkit.Differential - derivative_variables = Symbol[] - order = 0 - while (_args[1] isa ModelingToolkit.Differential) - order += 1 - push!(derivative_variables, toexpr(_args[1].x)) - _args = _args[2].args - end - depvar = _args[1] - num_depvar = dict_depvars[depvar] - indvars = _args[2:end] - dict_interior_indvars = Dict([indvar .=> j - for (j, indvar) in enumerate(dict_depvar_input[depvar])]) - dim_l = length(dict_interior_indvars) - - var_ = is_integral ? 
:(derivative) : :($(Expr(:$, :derivative))) - εs = [get_ε(dim_l, d, eltypeθ, order) for d in 1:dim_l] - undv = [dict_interior_indvars[d_p] for d_p in derivative_variables] - εs_dnv = [εs[d] for d in undv] - - ex.args = if !multioutput - [var_, :phi, :u, Symbol(:cord, num_depvar), εs_dnv, order, :($θ)] - else - [ - var_, - Symbol(:phi, num_depvar), - :u, - Symbol(:cord, num_depvar), - εs_dnv, - order, - Symbol(:($θ), num_depvar) - ] - end - break - elseif e isa Symbolics.Integral - if _args[1].domain.variables isa Tuple - integrating_variable_ = collect(_args[1].domain.variables) - integrating_variable = toexpr.(integrating_variable_) - integrating_var_id = [dict_indvars[i] for i in integrating_variable] - else - integrating_variable = toexpr(_args[1].domain.variables) - integrating_var_id = [dict_indvars[integrating_variable]] - end - - integrating_depvars = [] - integrand_expr = _args[2] - for d in depvars - d_ex = find_thing_in_expr(integrand_expr, d) - if !isempty(d_ex) - push!(integrating_depvars, d_ex[1].args[1]) - end - end - - lb, ub = get_limits(_args[1].domain.domain) - lb, ub, _args[2], dict_transformation_vars, transformation_vars = transform_inf_integral( - lb, - ub, - _args[2], - integrating_depvars, - dict_depvar_input, - dict_depvars, - integrating_variable, - eltypeθ) - - num_depvar = map(int_depvar -> dict_depvars[int_depvar], - integrating_depvars) - integrand_ = transform_expression(pinnrep, _args[2]; - is_integral = false, - dict_transformation_vars = dict_transformation_vars, - transformation_vars = transformation_vars) - integrand__ = _dot_(integrand_) - - integrand = build_symbolic_loss_function(pinnrep, nothing; - integrand = integrand__, - integrating_depvars = integrating_depvars, - eq_params = SciMLBase.NullParameters(), - dict_transformation_vars = dict_transformation_vars, - transformation_vars = transformation_vars, - param_estim = false, - default_p = nothing) - # integrand = repr(integrand) - lb = toexpr.(lb) - ub = toexpr.(ub) - ub_ = [] - lb_ = [] - for l in lb - if l isa Number - push!(lb_, l) - else - l_expr = NeuralPDE.build_symbolic_loss_function(pinnrep, nothing; - integrand = _dot_(l), - integrating_depvars = integrating_depvars, - param_estim = false, - default_p = nothing) - l_f = @RuntimeGeneratedFunction(l_expr) - push!(lb_, l_f) - end - end - for u_ in ub - if u_ isa Number - push!(ub_, u_) - else - u_expr = NeuralPDE.build_symbolic_loss_function(pinnrep, nothing; - integrand = _dot_(u_), - integrating_depvars = integrating_depvars, - param_estim = false, - default_p = nothing) - u_f = @RuntimeGeneratedFunction(u_expr) - push!(ub_, u_f) - end - end - - integrand_func = @RuntimeGeneratedFunction(integrand) - ex.args = [ - :($(Expr(:$, :integral))), - :u, - Symbol(:cord, num_depvar[1]), - :phi, - integrating_var_id, - integrand_func, - lb_, - ub_, - :($θ) - ] - break - end - else - ex.args[i] = _transform_expression(pinnrep, ex.args[i]; - is_integral = is_integral, - dict_transformation_vars = dict_transformation_vars, - transformation_vars = transformation_vars) - end - end - return ex -end """ Parse ModelingToolkit equation form to the inner representation. 
@@ -343,79 +185,9 @@ function pair(eq, depvars, dict_depvars, dict_depvar_input) Dict(filter(p -> p !== nothing, pair_)) end -function get_vars(indvars_, depvars_) - indvars = ModelingToolkit.getname.(indvars_) - depvars = Symbol[] - dict_depvar_input = Dict{Symbol, Vector{Symbol}}() - for d in depvars_ - if unwrap(d) isa SymbolicUtils.BasicSymbolic - dname = ModelingToolkit.getname(d) - push!(depvars, dname) - push!(dict_depvar_input, - dname => [nameof(unwrap(argument)) - for argument in arguments(unwrap(d))]) - else - dname = ModelingToolkit.getname(d) - push!(depvars, dname) - push!(dict_depvar_input, dname => indvars) # default to all inputs if not given - end - end - - dict_indvars = get_dict_vars(indvars) - dict_depvars = get_dict_vars(depvars) - return depvars, indvars, dict_indvars, dict_depvars, dict_depvar_input -end - -function get_integration_variables(eqs, _indvars::Array, _depvars::Array) - depvars, indvars, dict_indvars, dict_depvars, dict_depvar_input = get_vars(_indvars, - _depvars) - get_integration_variables(eqs, dict_indvars, dict_depvars) -end - -function get_integration_variables(eqs, dict_indvars, dict_depvars) - exprs = toexpr.(eqs) - vars = map(exprs) do expr - _vars = Symbol.(filter(indvar -> length(find_thing_in_expr(expr, indvar)) > 0, - sort(collect(keys(dict_indvars))))) - end -end - -""" - get_variables(eqs,_indvars,_depvars) - -Returns all variables that are used in each equations or boundary condition. -""" -function get_variables end - -function get_variables(eqs, _indvars::Array, _depvars::Array) - depvars, indvars, dict_indvars, dict_depvars, dict_depvar_input = get_vars(_indvars, - _depvars) - return get_variables(eqs, dict_indvars, dict_depvars) -end - -function get_variables(eqs, dict_indvars, dict_depvars) - bc_args = get_argument(eqs, dict_indvars, dict_depvars) - return map(barg -> filter(x -> x isa Symbol, barg), bc_args) -end - -function get_number(eqs, dict_indvars, dict_depvars) - bc_args = get_argument(eqs, dict_indvars, dict_depvars) - return map(barg -> filter(x -> x isa Number, barg), bc_args) -end - -function find_thing_in_expr(ex::Expr, thing; ans = []) - if thing in ex.args - push!(ans, ex) - end - for e in ex.args - if e isa Expr - if thing in e.args - push!(ans, e) - end - find_thing_in_expr(e, thing; ans = ans) - end - end - return collect(Set(ans)) +function get_integration_variables(eqs, v::VariableMap) + ivs = all_ivs(v) + return map(eq -> get_indvars(eq, ivs), eqs) end """ @@ -425,34 +197,45 @@ Returns all arguments that are used in each equations or boundary condition. 
""" function get_argument end -# Get arguments from boundary condition functions -function get_argument(eqs, _indvars::Array, _depvars::Array) - depvars, indvars, dict_indvars, dict_depvars, dict_depvar_input = get_vars(_indvars, - _depvars) - get_argument(eqs, dict_indvars, dict_depvars) -end -function get_argument(eqs, dict_indvars, dict_depvars) - exprs = toexpr.(eqs) - vars = map(exprs) do expr - _vars = map(depvar -> find_thing_in_expr(expr, depvar), collect(keys(dict_depvars))) +function get_argument(eqs, v::VariableMap) + vars = map(eqs) do eq + _vars = map(depvar -> get_depvars(eq, [depvar]), v.depvar_ops) f_vars = filter(x -> !isempty(x), _vars) - map(x -> first(x), f_vars) + map(first, f_vars) end args_ = map(vars) do _vars - ind_args_ = map(var -> var.args[2:end], _vars) - syms = Set{Symbol}() - filter(vcat(ind_args_...)) do ind_arg - if ind_arg isa Symbol - if ind_arg ∈ syms + seen = [] + filter(reduce(vcat, arguments.(_vars), init = [])) do x + if x isa Number + true + else + if any(isequal(x), seen) false else - push!(syms, ind_arg) + push!(seen, x) true end - else - true end end end return args_ # TODO for all arguments end + +""" +``julia +get_variables(eqs,_indvars,_depvars) +``` + +Returns all variables that are used in each equations or boundary condition. +""" +function get_variables(eqs, v::VariableMap) + args = get_argument(eqs, v) + return map(arg -> filter(x -> !(x isa Number), arg), args) +end + +function get_number(eqs, v::VariableMap) + args = get_argument(eqs, v) + return map(arg -> filter(x -> x isa Number, arg), args) +end + +sym_op(u) = Symbol(operation(u)) diff --git a/src/training_strategies.jl b/src/training_strategies.jl index 858e93a237..a2e9dbc369 100644 --- a/src/training_strategies.jl +++ b/src/training_strategies.jl @@ -19,7 +19,7 @@ function merge_strategy_with_loglikelihood_function(pinnrep::PINNRepresentation, strategy::GridTraining, datafree_pde_loss_function, datafree_bc_loss_function; train_sets_pde = nothing, train_sets_bc = nothing) - @unpack domains, eqs, bcs, dict_indvars, dict_depvars, flat_init_params = pinnrep + @unpack domains, eqs, bcs, varmap, flat_init_params = pinnrep eltypeθ = eltype(pinnrep.flat_init_params) @@ -54,12 +54,12 @@ function merge_strategy_with_loss_function(pinnrep::PINNRepresentation, strategy::GridTraining, datafree_pde_loss_function, datafree_bc_loss_function) - @unpack domains, eqs, bcs, dict_indvars, dict_depvars, flat_init_params = pinnrep + @unpack domains, eqs, bcs, varmap, flat_init_params = pinnrep dx = strategy.dx eltypeθ = eltype(pinnrep.flat_init_params) train_sets = generate_training_sets(domains, dx, eqs, bcs, eltypeθ, - dict_indvars, dict_depvars) + varmap) # the points in the domain and on the boundary pde_train_sets, bcs_train_sets = train_sets @@ -112,11 +112,11 @@ function merge_strategy_with_loss_function(pinnrep::PINNRepresentation, strategy::StochasticTraining, datafree_pde_loss_function, datafree_bc_loss_function) - @unpack domains, eqs, bcs, dict_indvars, dict_depvars, flat_init_params = pinnrep + @unpack domains, eqs, bcs, varmap, flat_init_params = pinnrep eltypeθ = eltype(pinnrep.flat_init_params) - bounds = get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, + bounds = get_bounds(domains, eqs, bcs, eltypeθ, varmap, strategy) pde_bounds, bcs_bounds = bounds @@ -192,11 +192,11 @@ function merge_strategy_with_loss_function(pinnrep::PINNRepresentation, strategy::QuasiRandomTraining, datafree_pde_loss_function, datafree_bc_loss_function) - @unpack domains, eqs, bcs, dict_indvars, 
dict_depvars, flat_init_params = pinnrep + @unpack domains, eqs, bcs, varmap, flat_init_params = pinnrep eltypeθ = eltype(pinnrep.flat_init_params) - bounds = get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, + bounds = get_bounds(domains, eqs, bcs, eltypeθ, varmap, strategy) pde_bounds, bcs_bounds = bounds @@ -283,10 +283,10 @@ function merge_strategy_with_loss_function(pinnrep::PINNRepresentation, strategy::QuadratureTraining, datafree_pde_loss_function, datafree_bc_loss_function) - @unpack domains, eqs, bcs, dict_indvars, dict_depvars, flat_init_params = pinnrep + @unpack domains, eqs, bcs, varmap, flat_init_params = pinnrep eltypeθ = eltype(pinnrep.flat_init_params) - bounds = get_bounds(domains, eqs, bcs, eltypeθ, dict_indvars, dict_depvars, + bounds = get_bounds(domains, eqs, bcs, eltypeθ, varmap, strategy) pde_bounds, bcs_bounds = bounds
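For reviewers unfamiliar with `SymbolicUtils` rewriters: `parse_equation` in the new `loss_function_generation.jl` builds a list of `@rule`s (via `generate_derivative_rules`) that turn `Differential` terms into `numeric_derivative` calls, then walks each residual with `Prewalk(Chain(rules))`. A toy sketch of that pattern follows; the rules `r1`/`r2` and the expression are illustrative stand-ins, not code from this PR.

```julia
# Minimal sketch of the Prewalk(Chain(rules)) pattern used by parse_equation.
# The rules here are toy identities; the PR's rules instead map Differential
# terms to numeric_derivative calls.
using Symbolics
using SymbolicUtils: @rule, @acrule
using SymbolicUtils.Rewriters: Prewalk, Chain

@variables x y
r1 = @acrule sin(~a)^2 + cos(~a)^2 => 1   # associative-commutative match inside sums
r2 = @rule exp(log(~a)) => ~a             # plain structural match
rw = Prewalk(Chain([r1, r2]))             # try each rule at every subterm, top-down

ex = Symbolics.unwrap(exp(log(y)) + sin(x)^2 + cos(x)^2)
rw(ex)                                    # -> 1 + y
```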
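The rewritten `numeric_derivative` dispatches on `order` and hard-codes central-difference stencils up to fourth order, erroring above that since the recursive higher-order path was removed. For reference, the same stencils in scalar form; `central_diff` is an illustrative helper for checking the formulas, not a NeuralPDE function.

```julia
# Scalar versions of the stencils in numeric_derivative; f: R -> R, ε: step size.
function central_diff(f, x, ε, order)
    if order == 1
        (f(x + ε) - f(x - ε)) / (2ε)
    elseif order == 2
        (f(x + ε) + f(x - ε) - 2f(x)) / ε^2
    elseif order == 3
        (f(x + 2ε) - 2f(x + ε) + 2f(x - ε) - f(x - 2ε)) / (2ε^3)
    elseif order == 4
        (f(x + 2ε) - 4f(x + ε) + 6f(x) - 4f(x - ε) + f(x - 2ε)) / ε^4
    else
        error("unsupported order $(order)")
    end
end

central_diff(sin, 0.3, 1e-3, 2) ≈ -sin(0.3)  # (d²/dx²) sin = -sin
```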
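None of these changes alter the user-facing entry points, so the standard workflow should still read roughly as below. This is a minimal sketch based on the documented public API; the network size, grid step, and optimizer are arbitrary choices, and the exact trained-parameter layout passed to `phi` may differ by version.

```julia
using NeuralPDE, Lux, ModelingToolkit, Optimization, OptimizationOptimJL
import ModelingToolkit: Interval

@parameters x y
@variables u(..)
Dxx = Differential(x)^2
Dyy = Differential(y)^2

# 2D Poisson equation with zero Dirichlet boundary conditions
eq = Dxx(u(x, y)) + Dyy(u(x, y)) ~ -sin(pi * x) * sin(pi * y)
bcs = [u(0, y) ~ 0, u(1, y) ~ 0, u(x, 0) ~ 0, u(x, 1) ~ 0]
domains = [x ∈ Interval(0.0, 1.0), y ∈ Interval(0.0, 1.0)]
@named pdesys = PDESystem(eq, bcs, domains, [x, y], [u(x, y)])

chain = Lux.Chain(Lux.Dense(2, 16, tanh), Lux.Dense(16, 16, tanh), Lux.Dense(16, 1))
discretization = PhysicsInformedNN(chain, GridTraining(0.1))

prob = discretize(pdesys, discretization)       # builds the OptimizationProblem
res = solve(prob, BFGS(); maxiters = 500)
u_pred = discretization.phi([0.5, 0.5], res.u)  # evaluate the trained network
```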