Skip to content

Commit 75c8308

Browse files
Merge pull request #1025 from ChrisRackauckas-Claude/bump-symbolics-mtk-compat
Bump Symbolics v7, SymbolicUtils v4, ModelingToolkit v11 compat
2 parents 89b4eda + 130cb8b commit 75c8308

13 files changed

+159
-65
lines changed

Project.toml

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,7 @@ LuxCore = "1.0.1"
8585
LuxLib = "1.3"
8686
MCMCChains = "7"
8787
MLDataDevices = "1.2.0"
88-
MethodOfLines = "0.11.6"
89-
ModelingToolkit = "9.46, 10"
88+
ModelingToolkit = "11"
9089
MonteCarloMeasurements = "1.1"
9190
NeuralOperators = "0.5, 0.6"
9291
Optimisers = "0.3.3, 0.4"
@@ -105,8 +104,8 @@ SciMLBase = "2.56"
105104
Statistics = "1.10"
106105
StochasticDiffEq = "6.69.1"
107106
SymbolicIndexingInterface = "0.3.31"
108-
SymbolicUtils = "3.7.2"
109-
Symbolics = "6.14"
107+
SymbolicUtils = "4.12"
108+
Symbolics = "7.8"
110109
TensorBoardLogger = "0.1.24"
111110
Test = "1.10"
112111
WeightInitializers = "1.0.3"
@@ -127,7 +126,6 @@ LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
127126
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
128127
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
129128
LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
130-
MethodOfLines = "94925ecb-adb7-4558-8ed8-f975c56a0bf4"
131129
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
132130
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
133131
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
@@ -136,4 +134,4 @@ TensorBoardLogger = "899adc3e-224a-11e9-021f-63837185c80f"
136134
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
137135

138136
[targets]
139-
test = ["Aqua", "Boltz", "CUDA", "DiffEqNoiseProcess", "ExplicitImports", "FastGaussQuadrature", "Flux", "Hwloc", "InteractiveUtils", "LineSearches", "LuxCUDA", "LuxCore", "LuxLib", "MethodOfLines", "OptimizationOptimJL", "OrdinaryDiffEq", "ReTestItems", "StochasticDiffEq", "TensorBoardLogger", "Test"]
137+
test = ["Aqua", "Boltz", "CUDA", "DiffEqNoiseProcess", "ExplicitImports", "FastGaussQuadrature", "Flux", "Hwloc", "InteractiveUtils", "LineSearches", "LuxCUDA", "LuxCore", "LuxLib", "OptimizationOptimJL", "OrdinaryDiffEq", "ReTestItems", "StochasticDiffEq", "TensorBoardLogger", "Test"]

docs/src/tutorials/constraints.md

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@ with Physics-Informed Neural Networks.
2222

2323
```@example fokkerplank
2424
using NeuralPDE, Lux, ModelingToolkit, Optimization, OptimizationOptimJL, LineSearches
25-
using Integrals, Cubature
2625
using DomainSets: Interval
2726
using IntervalSets: leftendpoint, rightendpoint
2827
# the example is taken from this article https://arxiv.org/abs/1910.10503
@@ -52,15 +51,19 @@ chain = Lux.Chain(Dense(1, inn, Lux.σ),
5251
Dense(inn, inn, Lux.σ),
5352
Dense(inn, 1))
5453
55-
lb = [x_0]
56-
ub = [x_end]
54+
lb = x_0
55+
ub = x_end
56+
# Use a simple trapezoidal rule for the normalization constraint.
57+
# This avoids AD issues with Integrals.jl's C-based quadrature solvers.
58+
norm_xs = collect(range(lb, ub, length = 200))
59+
norm_dx = Float64(norm_xs[2] - norm_xs[1])
5760
function norm_loss_function(phi, θ, p)
58-
function inner_f(x, θ)
59-
0.01 * phi(x, θ) .- 1
61+
# Evaluate phi at quadrature points (each point as a 1-element vector)
62+
s = sum(1:length(norm_xs)) do i
63+
first(phi([norm_xs[i]], θ))
6064
end
61-
prob = IntegralProblem(inner_f, lb, ub, θ)
62-
norm2 = solve(prob, HCubatureJL(), reltol = 1e-8, abstol = 1e-8, maxiters = 10)
63-
abs(norm2[1])
65+
norm_val = 0.01 * s * norm_dx
66+
abs(norm_val - 1)
6467
end
6568
6669
discretization = PhysicsInformedNN(chain,

src/NeuralPDE.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ using WeightInitializers: glorot_uniform, zeros32
3636
using Zygote: Zygote
3737

3838
# Symbolic Stuff
39-
using ModelingToolkit: ModelingToolkit, PDESystem, Differential, toexpr, defaults
40-
using Symbolics: Symbolics, unwrap, arguments, operation, build_expr, Num,
39+
using ModelingToolkit: ModelingToolkit, PDESystem, Differential, toexpr
40+
using Symbolics: Symbolics, unwrap, arguments, operation, Num,
4141
expand_derivatives
4242
using SymbolicUtils: SymbolicUtils
4343
using SymbolicIndexingInterface: SymbolicIndexingInterface

src/PDE_BPINN.jl

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -85,19 +85,30 @@ function get_symbols(dataset, depvars, eqs)
8585
# order of pinnrep.depvars, depvar_vals, BayesianPINN.dataset must be same
8686
to_subs = Dict(depvars .=> depvar_vals)
8787

88-
numform_vars = Symbolics.get_variables.(eqs)
89-
Eq_vars = unique(reduce(vcat, numform_vars))
90-
# got equation's depvar num format {x(t)} for use in substitute()
91-
88+
# Find callable depvar terms (e.g., u(x,t)) by searching the equation tree.
89+
# In SymbolicUtils v4, get_variables may decompose u(x,t) into bare u, x, t,
90+
# so we search directly for callable terms whose operation matches a depvar name.
91+
depvar_set = Set(depvars)
9292
tobe_subs = Dict()
93-
for a in depvars
94-
for i in Eq_vars
95-
expr = toexpr(i)
96-
if (expr isa Expr) && (expr.args[1] == a)
97-
tobe_subs[a] = i
93+
function _search(term)
94+
t = SymbolicUtils.unwrap(term)
95+
return if SymbolicUtils.iscall(t)
96+
op = SymbolicUtils.operation(t)
97+
if !SymbolicUtils.iscall(op) && !(op isa Differential)
98+
name = toexpr(op)
99+
if name isa Symbol && name in depvar_set
100+
tobe_subs[name] = t
101+
end
102+
end
103+
for arg in SymbolicUtils.arguments(t)
104+
_search(arg)
98105
end
99106
end
100107
end
108+
for eq in eqs
109+
_search(eq.lhs)
110+
_search(eq.rhs)
111+
end
101112
# depvar symbolic and num format got, tobe_subs : Dict{Any, Any}(:y => y(t), :x => x(t))
102113

103114
return to_subs, tobe_subs

src/discretize.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ For more information, see `discretize` and `PINNRepresentation`.
403403
function SciMLBase.symbolic_discretize(pde_system::PDESystem, discretization::AbstractPINN)
404404
(; eqs, bcs, domain) = pde_system
405405
eq_params = pde_system.ps
406-
defaults = pde_system.defaults
406+
defaults = pde_system.initial_conditions
407407
(;
408408
chain, param_estim, additional_loss, multioutput, init_params, phi,
409409
derivative, strategy, logger, iteration, self_increment,
@@ -412,7 +412,7 @@ function SciMLBase.symbolic_discretize(pde_system::PDESystem, discretization::Ab
412412
adaloss = discretization.adaptive_loss
413413

414414
default_p = eq_params isa SciMLBase.NullParameters ? nothing :
415-
[defaults[ep] for ep in eq_params]
415+
[Symbolics.value(defaults[ep]) for ep in eq_params]
416416

417417
depvars, indvars, dict_indvars,
418418
dict_depvars, dict_depvar_input = get_vars(

src/symbolic_utilities.jl

Lines changed: 30 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
using Base.Broadcast
22

3+
# build_expr was removed from Symbolics.jl v7; define locally
4+
function build_expr(head::Symbol, args)
5+
ex = Expr(head)
6+
append!(ex.args, args)
7+
return ex
8+
end
9+
310
"""
411
Override `Broadcast.__dot__` with `Broadcast.dottable(x::Function) = true`
512
@@ -155,8 +162,11 @@ function _transform_expression(
155162
derivative_variables = Symbol[]
156163
order = 0
157164
while (_args[1] isa Differential)
158-
order += 1
159-
push!(derivative_variables, toexpr(_args[1].x))
165+
d_order = _args[1].order
166+
order += d_order
167+
for _ in 1:d_order
168+
push!(derivative_variables, toexpr(_args[1].x))
169+
end
160170
_args = _args[2].args
161171
end
162172
depvar = _args[1]
@@ -197,7 +207,19 @@ function _transform_expression(
197207
integrating_var_id = [dict_indvars[i] for i in integrating_variable]
198208
else
199209
integrating_variable = toexpr(_args[1].domain.variables)
200-
integrating_var_id = [dict_indvars[integrating_variable]]
210+
# In SymbolicUtils v4, a symbolic tuple may produce
211+
# Expr(:call, :tuple, :x, :y) instead of a Julia Tuple
212+
if integrating_variable isa Expr &&
213+
integrating_variable.head == :call &&
214+
integrating_variable.args[1] === tuple
215+
integrating_variable = integrating_variable.args[2:end]
216+
integrating_var_id = [
217+
dict_indvars[i]
218+
for i in integrating_variable
219+
]
220+
else
221+
integrating_var_id = [dict_indvars[integrating_variable]]
222+
end
201223
end
202224

203225
integrating_depvars = []
@@ -336,8 +358,10 @@ Parse ModelingToolkit equation form to the inner representation.
336358
(derivative(phi2, u2, [x, y], [[ε,0]], 1, θ2) + 9 * derivative(phi1, u, [x, y], [[0,ε]], 1, θ1)) - 0]
337359
"""
338360
function parse_equation(pinnrep::PINNRepresentation, eq)
339-
eq_lhs = isequal(expand_derivatives(eq.lhs), 0) ? eq.lhs : expand_derivatives(eq.lhs)
340-
eq_rhs = isequal(expand_derivatives(eq.rhs), 0) ? eq.rhs : expand_derivatives(eq.rhs)
361+
eq_lhs = SymbolicUtils._iszero(expand_derivatives(eq.lhs)) ? eq.lhs :
362+
expand_derivatives(eq.lhs)
363+
eq_rhs = SymbolicUtils._iszero(expand_derivatives(eq.rhs)) ? eq.rhs :
364+
expand_derivatives(eq.rhs)
341365
left_expr = transform_expression(pinnrep, toexpr(eq_lhs))
342366
right_expr = transform_expression(pinnrep, toexpr(eq_rhs))
343367
left_expr = _dot_(left_expr)
@@ -379,7 +403,7 @@ function get_vars(indvars_, depvars_)
379403
depvars = Symbol[]
380404
dict_depvar_input = Dict{Symbol, Vector{Symbol}}()
381405
for d in depvars_
382-
if unwrap(d) isa SymbolicUtils.BasicSymbolic
406+
if SymbolicUtils.iscall(unwrap(d))
383407
dname = SymbolicIndexingInterface.getname(d)
384408
push!(depvars, dname)
385409
push!(

src/training_strategies.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -422,7 +422,7 @@ function get_loss_function(
422422
return solve(
423423
prob, strategy.quadrature_alg; strategy.reltol, strategy.abstol,
424424
strategy.maxiters
425-
)[1]
425+
).u
426426
end
427427
return (θ) -> f_(lb, ub, loss_function, θ) / area
428428
end

src/transform_inf_integral.jl

Lines changed: 47 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ function v_inf(t)
4343
end
4444

4545
function v_semiinf(t, a, upto_inf)
46-
if a isa Num
46+
if a isa Num || a isa SymbolicUtils.BasicSymbolic
4747
if upto_inf == true
4848
return :($t ./ (1 .- $t))
4949
else
@@ -84,6 +84,16 @@ function transform_inf_integral(
8484
)
8585
lb_ = Symbolics.tosymbol.(lb)
8686
ub_ = Symbolics.tosymbol.(ub)
87+
# Convert bounds to plain Julia numbers where possible to avoid Symbolics
88+
# arithmetic in boolean mask operations (e.g., false * Num(Inf) gives NaN)
89+
lb = map(
90+
(l, ls) -> ls === -Inf ? -Inf : ls === Inf ? Inf : (l isa Num ? SymbolicUtils.unwrap(l) : l),
91+
lb, lb_
92+
)
93+
ub = map(
94+
(u, us) -> us === Inf ? Inf : us === -Inf ? -Inf : (u isa Num ? SymbolicUtils.unwrap(u) : u),
95+
ub, ub_
96+
)
8797

8898
if -Inf in lb_ || Inf in ub_
8999
if !(integrating_variable isa Array)
@@ -118,10 +128,42 @@ function transform_inf_integral(
118128

119129
ϵ = 1 / 20 #cbrt(eps(eltypeθ))
120130

121-
lb = 0.0 .* _semiup + (-1.0 + ϵ) .* _inf + (-1.0 + ϵ) .* _semilw + _none .* lb +
122-
lb ./ (1 .+ lb) .* _num_semiup + (-1.0 + ϵ) .* _num_semilw
123-
ub = (1.0 - ϵ) .* _semiup + (1.0 - ϵ) .* _inf + 0.0 .* _semilw + _none .* ub +
124-
(1.0 - ϵ) .* _num_semiup + ub ./ (1 .+ ub) .* _num_semilw
131+
# Use conditional logic instead of boolean mask multiplication to avoid
132+
# NaN from 0.0 * Inf (IEEE754 behavior) with symbolic infinity bounds
133+
lb = map(eachindex(lb)) do i
134+
if _none[i]
135+
lb[i]
136+
elseif _inf[i]
137+
-1.0 + ϵ
138+
elseif _semiup[i]
139+
0.0
140+
elseif _semilw[i]
141+
-1.0 + ϵ
142+
elseif _num_semiup[i]
143+
lb[i] / (1 + lb[i])
144+
elseif _num_semilw[i]
145+
-1.0 + ϵ
146+
else
147+
lb[i]
148+
end
149+
end
150+
ub = map(eachindex(ub)) do i
151+
if _none[i]
152+
ub[i]
153+
elseif _inf[i]
154+
1.0 - ϵ
155+
elseif _semiup[i]
156+
1.0 - ϵ
157+
elseif _semilw[i]
158+
0.0
159+
elseif _num_semiup[i]
160+
1.0 - ϵ
161+
elseif _num_semilw[i]
162+
ub[i] / (1 + ub[i])
163+
else
164+
ub[i]
165+
end
166+
end
125167

126168
j = get_inf_transformation_jacobian(
127169
integrating_var_transformation, _inf, _semiup,

test/BPINN_PDE_tests.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ end
262262
[t],
263263
[u(t)],
264264
[p],
265-
defaults = Dict([p => 4.0])
265+
initial_conditions = Dict([p => 4.0])
266266
)
267267

268268
analytic_sol_func1(u0, t) = u0 + sinpi(2t) / (2π)
@@ -355,7 +355,7 @@ end
355355

356356
@named pde_system = PDESystem(
357357
eqs, bcs, domains,
358-
[t], [x(t), y(t), z(t)], [σ_], defaults = Dict([p => 1.0 for p in [σ_]])
358+
[t], [x(t), y(t), z(t)], [σ_], initial_conditions = Dict([p => 1.0 for p in [σ_]])
359359
)
360360

361361
sol1 = ahmc_bayesian_pinn_pde(
@@ -491,7 +491,7 @@ end
491491
[x, t],
492492
[u(x, t)],
493493
[α],
494-
defaults = Dict([α => 2.0])
494+
initial_conditions = Dict([α => 2.0])
495495
)
496496

497497
# neccesarry for loss function construction (involves Operator masking)

test/NNPDE_tests.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,8 @@ end
170170
@testitem "PDE II: 2D Poisson" tags = [:nnpde1] setup = [NNPDE1TestSetup] begin
171171
using Lux, Random, Optimisers, DomainSets, Cubature, QuasiMonteCarlo, Integrals
172172
import DomainSets: Interval, infimum, supremum
173+
using OptimizationOptimJL: BFGS
174+
using LineSearches: BackTracking
173175

174176
function test_2d_poisson_equation(chain, strategy)
175177
@parameters x y
@@ -196,7 +198,9 @@ end
196198
discretization = PhysicsInformedNN(chain, strategy; init_params = ps)
197199
@named pde_system = PDESystem(eq, bcs, domains, [x, y], [u(x, y)])
198200
prob = discretize(pde_system, discretization)
199-
res = solve(prob, Adam(0.1); maxiters = 500, callback)
201+
res = solve(prob, Adam(0.01); maxiters = 1000, callback)
202+
prob = remake(prob, u0 = res.u)
203+
res = solve(prob, BFGS(linesearch = BackTracking()); maxiters = 1000)
200204
phi = discretization.phi
201205

202206
xs, ys = [infimum(d.domain):0.01:supremum(d.domain) for d in domains]
@@ -404,9 +408,9 @@ end
404408
end
405409

406410
# Adam warmup for robustness, then BFGS for convergence
407-
res = solve(prob, Adam(0.01); maxiters = 500)
411+
res = solve(prob, Adam(0.01); maxiters = 1000)
408412
prob = remake(prob, u0 = res.u)
409-
res = solve(prob, BFGS(linesearch = BackTracking()); maxiters = 500)
413+
res = solve(prob, BFGS(linesearch = BackTracking()); maxiters = 1000)
410414

411415
dx = 0.1
412416
xs, ts = [infimum(d.domain):dx:supremum(d.domain) for d in domains]

0 commit comments

Comments
 (0)