diff --git a/docs/src/examples/brusselator_reactant.jl b/docs/src/examples/brusselator_reactant.jl
new file mode 100644
index 000000000..cd77e7127
--- /dev/null
+++ b/docs/src/examples/brusselator_reactant.jl
@@ -0,0 +1,224 @@
+using Lux, Random, Reactant, Enzyme, MLUtils, Optimisers, OnlineStats, CairoMakie, Statistics, Printf
+
+const T = Float32
+const device_func = reactant_device(; force=true)
+
+struct PINNBrusselator{U,V} <: AbstractLuxContainerLayer{(:u, :v)}
+    u::U
+    v::V
+end
+
+function create_mlp(act, hidden)
+    Chain(Dense(3 => hidden, act), Dense(hidden => hidden, act),
+        Dense(hidden => hidden, act), Dense(hidden => 1))
+end
+
+function PINNBrusselator(; hidden=128)
+    PINNBrusselator(create_mlp(Lux.swish, hidden), create_mlp(Lux.swish, hidden))
+end
+
+struct Normalizer{T}
+    min_vals::T
+    max_vals::T
+end
+
+(n::Normalizer)(x) = (x .- n.min_vals) ./ (n.max_vals .- n.min_vals)
+# Extend Base.inv instead of shadowing it, so `inv` keeps working for other types.
+Base.inv(n::Normalizer) = x -> x .* (n.max_vals .- n.min_vals) .+ n.min_vals
+
+# Initial conditions. The exponent is kept in Float32 to avoid promotion to Float64.
+u₀(x, y) = T(22) * (y * (1 - y))^T(1.5)
+v₀(x, y) = T(27) * (x * (1 - x))^T(1.5)
+
+# Localized forcing term (scalar reference definition).
+f(x, y, t) = (t ≥ 1.1f0 && (x - 0.3f0)^2 + (y - 0.6f0)^2 ≤ 0.01f0) ? T(5.0) : T(0.0)
+
+# Vectorized forcing over a 3×N batch of (x, y, t) columns. The 1:1 slices keep the
+# result a 1×N row so it broadcasts against the 1×N derivative slices below.
+function f_batch(coords)
+    x, y, t = coords[1:1, :], coords[2:2, :], coords[3:3, :]
+    mask = ((x .- 0.3f0).^2 .+ (y .- 0.6f0).^2 .<= 0.01f0) .& (t .>= 1.1f0)
+    ifelse.(mask, T(5.0), T(0.0))
+end
+
+function first_derivs(net::StatefulLuxLayer, xyt)
+    grads = Enzyme.gradient(Enzyme.Reverse, sum ∘ net, xyt)[1]
+    grads[1:1, :], grads[2:2, :], grads[3:3, :]
+end
+
+function laplacian(net::StatefulLuxLayer, xyt)
+    fx(x) = sum(first_derivs(net, x)[1])
+    fy(x) = sum(first_derivs(net, x)[2])
+    d2x = Enzyme.gradient(Enzyme.Reverse, fx, xyt)[1][1:1, :]
+    d2y = Enzyme.gradient(Enzyme.Reverse, fy, xyt)[1][2:2, :]
+    d2x .+ d2y
+end
+
+function pde_residual(u, v, xyt, α, f_vals)
+    u_pred = u(xyt)
+    v_pred = v(xyt)
+    _, _, ∂u_∂t = first_derivs(u, xyt)
+    _, _, ∂v_∂t = first_derivs(v, xyt)
+    ∇²u = laplacian(u, xyt)
+    ∇²v = laplacian(v, xyt)
+
+    res_u = ∂u_∂t .- (T(1.0) .+ u_pred.^2 .* v_pred .- T(4.4) .* u_pred .+ α .* ∇²u .+ f_vals)
+    res_v = ∂v_∂t .- (T(3.4) .* u_pred .- u_pred.^2 .* v_pred .+ α .* ∇²v)
+    res_u, res_v
+end
+
+function ic_loss(u, v, xyt, target_u, target_v)
+    pu = u(xyt)
+    pv = v(xyt)
+    mean(abs2, pu .- target_u) + mean(abs2, pv .- target_v)
+end
+
+function bc_loss(u, v, x0, x1, y0, y1)
+    ux0 = u(x0); ux1 = u(x1)
+    uy0 = u(y0); uy1 = u(y1)
+    vx0 = v(x0); vx1 = v(x1)
+    vy0 = v(y0); vy1 = v(y1)
+    mean(abs2, ux0 .- ux1) + mean(abs2, uy0 .- uy1) +
+        mean(abs2, vx0 .- vx1) + mean(abs2, vy0 .- vy1)
+end
+
+function loss_fn(model, ps, st, data)
+    u_net = StatefulLuxLayer{true}(model.u, ps.u, st.u)
+    v_net = StatefulLuxLayer{true}(model.v, ps.v, st.v)
+    pde_xyt, ic_data, bc_data, denorm, α = data
+
+    actual = denorm(pde_xyt)
+    fvals = f_batch(actual)
+
+    res_u, res_v = pde_residual(u_net, v_net, pde_xyt, α, fvals)
+    loss_pde = mean(abs2, res_u) + mean(abs2, res_v)
+
+    ic_xyt, u_ic, v_ic = ic_data
+    loss_ic = ic_loss(u_net, v_net, ic_xyt, u_ic, v_ic)
+
+    x0, x1, y0, y1 = bc_data
+    loss_bc = bc_loss(u_net, v_net, x0, x1, y0, y1)
+
+    loss = loss_pde + 1000f0 * loss_ic + 100f0 * loss_bc
+    # Return the updated states from the stateful wrappers, not the input states.
+    return loss, (; u=u_net.st, v=v_net.st), (; loss_pde, loss_ic, loss_bc)
+end
+
+function train_brusselator!()
+    rng = Random.default_rng()
+    Random.seed!(rng, 0)
+
+    α = T(0.001)
+    tspan = (0f0, 11.5f0)
+    xspan = (0f0, 1f0)
+    yspan = (0f0, 1f0)
+
+    pde_n = 10_000; ic_n = 2000; bc_n = 2000
+
+    x_pde = rand(rng, T, pde_n)
+    y_pde = rand(rng, T, pde_n)
+    # Sample t over the full tspan, not just [0, 1).
+    t_pde = tspan[1] .+ (tspan[2] - tspan[1]) .* rand(rng, T, pde_n)
+
+    xyt_pde = vcat(x_pde', y_pde', t_pde')
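+
+    # NOTE (illustrative): uniform sampling of the space-time box is the simplest
+    # choice here. If late times train poorly, one common optional variant is to
+    # bias collocation points toward t = 0 early in training, e.g.
+    #     t_pde = tspan[2] .* rand(rng, T, pde_n) .^ 2
+    # This curriculum-style heuristic is a sketch, not part of the original example.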
+
+    x_ic = rand(rng, T, ic_n)
+    y_ic = rand(rng, T, ic_n)
+    t_ic = fill(T(0.0), ic_n)
+
+    xyt_ic = vcat(x_ic', y_ic', t_ic')
+    u_ic = reshape(u₀.(x_ic, y_ic), 1, :)
+    v_ic = reshape(v₀.(x_ic, y_ic), 1, :)
+
+    y_bc = rand(rng, T, bc_n)
+    t_bc = tspan[1] .+ (tspan[2] - tspan[1]) .* rand(rng, T, bc_n)
+    x0 = vcat(fill(xspan[1], bc_n)', y_bc', t_bc')
+    x1 = vcat(fill(xspan[2], bc_n)', y_bc', t_bc')
+    x_bc = rand(rng, T, bc_n)
+    y0 = vcat(x_bc', fill(yspan[1], bc_n)', t_bc')
+    y1 = vcat(x_bc', fill(yspan[2], bc_n)', t_bc')
+
+    mins = T.([xspan[1], yspan[1], tspan[1]])
+    maxs = T.([xspan[2], yspan[2], tspan[2]])
+    normalizer = Normalizer(mins, maxs)
+    denormalizer = inv(normalizer)
+
+    # Normalize on the CPU; the data loaders move batches to the device below.
+    xyt_pde = normalizer(xyt_pde)
+    xyt_ic = normalizer(xyt_ic)
+    x0 = normalizer(x0); x1 = normalizer(x1)
+    y0 = normalizer(y0); y1 = normalizer(y1)
+
+    model = PINNBrusselator()
+    ps, st = Lux.setup(rng, model) |> device_func
+    train_state = Lux.Training.TrainState(model, ps, st, Optimisers.Adam(T(0.001)))
+
+    pde_loader = DataLoader(xyt_pde; batchsize=256, shuffle=true, partial=false) |> device_func
+    ic_loader = DataLoader((xyt_ic, u_ic, v_ic); batchsize=256, shuffle=true, partial=false) |> device_func
+    bc_loader = DataLoader((x0, x1, y0, y1); batchsize=128, shuffle=true, partial=false) |> device_func
+
+    loss_trackers = ntuple(_ -> OnlineStats.CircBuff(T, 32), 4)
+    max_iters = 50000
+    lr = i -> i < 10000 ? T(0.001) : (i < 30000 ? T(0.0001) : T(1e-5))
+
+    for (i, (xyt, ic, bc)) in enumerate(zip(
+        Iterators.cycle(pde_loader), Iterators.cycle(ic_loader), Iterators.cycle(bc_loader)))
+        Optimisers.adjust!(train_state, lr(i))
+        data = (xyt, ic, bc, denormalizer, α)
+
+        # `single_train_step!` returns (grads, loss, stats, train_state).
+        _, loss, stats, train_state = Lux.Training.single_train_step!(
+            AutoEnzyme(), loss_fn, data, train_state; return_gradients=Val(false))
+
+        fit!.(loss_trackers, (T(loss), T(stats.loss_pde), T(stats.loss_ic), T(stats.loss_bc)))
+
+        if i % 1000 == 1 || i == max_iters
+            m = mean ∘ OnlineStats.value
+            @printf "Iter: %5d Loss: %.6e PDE: %.2e IC: %.2e BC: %.2e\n" i T(loss) m(loss_trackers[2]) m(loss_trackers[3]) m(loss_trackers[4])
+        end
+        i ≥ max_iters && break
+    end
+
+    return train_state, normalizer, denormalizer
+end
+
+train_state, normalizer, denormalizer = train_brusselator!()
+
+function visualize_brusselator(train_state, normalizer, denormalizer)
+    xs = range(0f0, 1f0; length=50)
+    ys = range(0f0, 1f0; length=50)
+    ts = range(0f0, 11.5f0; length=40)
+
+    # Build a 3×N matrix of (x, y, t) columns; x varies fastest, then y, then t,
+    # so predictions reshape back into (x, y, t) order below.
+    grid = reshape(stack([[x, y, t] for x in xs, y in ys, t in ts]), 3, :)
+
+    # Inference runs on CPU copies of the trained parameters, so keep the grid on CPU.
+    norm_grid = normalizer(grid)
+
+    u_net = StatefulLuxLayer{true}(train_state.model.u, cpu_device()(train_state.parameters.u), cpu_device()(train_state.states.u))
+    v_net = StatefulLuxLayer{true}(train_state.model.v, cpu_device()(train_state.parameters.v), cpu_device()(train_state.states.v))
+
+    u_pred = reshape(Array(u_net(norm_grid)), length(xs), length(ys), length(ts))
+    v_pred = reshape(Array(v_net(norm_grid)), length(xs), length(ys), length(ts))
+
+    fig_u = Figure(size=(800, 600))
+    ax_u = Axis(fig_u[1, 1], xlabel="x", ylabel="y", title="U")
+    umin, umax = extrema(u_pred)
+    plt_u = heatmap!(ax_u, xs, ys, u_pred[:, :, 1]; colorrange=(umin, umax))
+    Colorbar(fig_u[1, 2], plt_u, label="U")
+
+    CairoMakie.record(fig_u, "brusselator_U.gif", 1:length(ts); framerate=10) do i
+        plt_u[3] = u_pred[:, :, i]
+        ax_u.title = "U Concentration | t = $(round(ts[i], digits=2))"
+    end
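+
+    # NOTE: `plt_u[3]` indexes the heatmap's third converted argument (the matrix of
+    # values), which Makie stores as an Observable; assigning a new frame updates the
+    # existing plot instead of re-plotting. The same pattern is reused for V below.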
+
+    fig_v = Figure(size=(800, 600))
+    ax_v = Axis(fig_v[1, 1], xlabel="x", ylabel="y", title="V")
+    vmin, vmax = extrema(v_pred)
+    plt_v = heatmap!(ax_v, xs, ys, v_pred[:, :, 1]; colorrange=(vmin, vmax))
+    Colorbar(fig_v[1, 2], plt_v, label="V")
+
+    CairoMakie.record(fig_v, "brusselator_V.gif", 1:length(ts); framerate=10) do i
+        plt_v[3] = v_pred[:, :, i]
+        ax_v.title = "V Concentration | t = $(round(ts[i], digits=2))"
+    end
+
+    println("Saved U to brusselator_U.gif and V to brusselator_V.gif")
+end
+
+visualize_brusselator(train_state, normalizer, denormalizer)
\ No newline at end of file
diff --git a/docs/src/examples/poisson_reactant.jl b/docs/src/examples/poisson_reactant.jl
new file mode 100644
index 000000000..5f3e5654c
--- /dev/null
+++ b/docs/src/examples/poisson_reactant.jl
@@ -0,0 +1,252 @@
+using Lux, Reactant, Random, Statistics, Enzyme, MLUtils, Printf, Optimisers, CairoMakie, OnlineStats
+
+# Right-hand side of Δu = f for a 2×N batch of (x, y) columns. The 1:1 slices keep
+# the result a 1×N row so it broadcasts against the 1×N derivative slices below.
+f_vec(xy) = -sin.(π .* xy[1:1, :]) .* sin.(π .* xy[2:2, :])
+
+struct PINN_Poisson{U,P,Q} <: AbstractLuxContainerLayer{(:u, :p, :q)}
+    u::U
+    p::P
+    q::Q
+end
+
+function create_mlp(act, hidden_dims, in_dims=2)
+    return Chain(
+        Dense(in_dims => hidden_dims, act),
+        Dense(hidden_dims => hidden_dims, act),
+        Dense(hidden_dims => hidden_dims, act),
+        Dense(hidden_dims => 1),
+    )
+end
+
+function PINN_Poisson(; hidden_dims::Int=128)
+    return PINN_Poisson(
+        create_mlp(Lux.swish, hidden_dims),
+        create_mlp(Lux.swish, hidden_dims),
+        create_mlp(Lux.swish, hidden_dims),
+    )
+end
+
+function (pinn::PINN_Poisson)(xy, ps, st)
+    u, st_u = Lux.apply(pinn.u, xy, ps.u, st.u)
+    p, st_p = Lux.apply(pinn.p, xy, ps.p, st.p)
+    q, st_q = Lux.apply(pinn.q, xy, ps.q, st.q)
+    return (u=u, p=p, q=q), merge(st, (u=st_u, p=st_p, q=st_q))
+end
+
+rng = Random.default_rng()
+Random.seed!(rng, 0)
+
+pinn_poisson = PINN_Poisson()
+ps_poisson, st_poisson = Lux.setup(rng, pinn_poisson) |> reactant_device()
+
+analytical_solution_poisson(x, y) = (sin(pi * x) * sin(pi * y)) / (2pi^2)
+analytical_solution_poisson(xy_batch) = analytical_solution_poisson.(xy_batch[1, :], xy_batch[2, :])
+
+grid_len = 64
+grid_x = range(0.0f0, 1.0f0; length=grid_len)
+grid_y = range(0.0f0, 1.0f0; length=grid_len)
+xy_pde_pts = stack([[elem...] for elem in vec(collect(Iterators.product(grid_x, grid_y)))])
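+
+# NOTE (illustrative): `Iterators.product(grid_x, grid_y)` varies x fastest, so the
+# columns of `xy_pde_pts` are ordered (x₁,y₁), (x₂,y₁), … For a 2-point grid:
+#     stack([[e...] for e in vec(collect(Iterators.product(0f0:1f0, 0f0:1f0)))])
+#     # 2×4 Matrix with columns (0,0), (1,0), (0,1), (1,1)
+# Reshapes back to (grid_len, grid_len) therefore place x along the first axis.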
+
+target_data_values = analytical_solution_poisson(xy_pde_pts)
+target_data_values = reshape(target_data_values, 1, :)
+
+bc_len = 256
+x_bc_scalar = collect(range(0.0f0, 1.0f0; length=bc_len))
+y_bc_scalar = collect(range(0.0f0, 1.0f0; length=bc_len))
+
+# Boundary points on the four edges x = 0, x = 1, y = 0, and y = 1.
+xy_bc_points = hcat(
+    [zeros(Float32, bc_len)'; y_bc_scalar'],
+    [ones(Float32, bc_len)'; y_bc_scalar'],
+    [x_bc_scalar'; zeros(Float32, bc_len)'],
+    [x_bc_scalar'; ones(Float32, bc_len)']
+)
+
+target_bc_values = zeros(Float32, size(xy_bc_points, 2))
+target_bc_values = reshape(target_bc_values, 1, :)
+
+min_target_data, max_target_data = extrema(target_data_values)
+min_target_bc, max_target_bc = extrema(target_bc_values)
+global_min_output_val = min(min_target_data, min_target_bc)
+global_max_output_val = max(max_target_data, max_target_bc)
+
+min_xy_pde = minimum(xy_pde_pts)
+max_xy_pde = maximum(xy_pde_pts)
+xy_pde_pts_normalized = (xy_pde_pts .- min_xy_pde) ./ (max_xy_pde - min_xy_pde)
+
+min_xy_bc = minimum(xy_bc_points)
+max_xy_bc = maximum(xy_bc_points)
+xy_bc_points_normalized = (xy_bc_points .- min_xy_bc) ./ (max_xy_bc - min_xy_bc)
+
+target_data_values_normalized = (target_data_values .- global_min_output_val) ./ (global_max_output_val - global_min_output_val)
+target_bc_values_normalized = (target_bc_values .- global_min_output_val) ./ (global_max_output_val - global_min_output_val)
+
+xs_plot = 0.0f0:0.02f0:1.0f0
+ys_plot = 0.0f0:0.02f0:1.0f0
+
+u_real_plot = [analytical_solution_poisson(x, y) for x in xs_plot, y in ys_plot]
+
+fig_true = Figure()
+ax_true = CairoMakie.Axis(fig_true[1, 1]; xlabel="x", ylabel="y", title="True Analytical Solution")
+CairoMakie.heatmap!(ax_true, xs_plot, ys_plot, u_real_plot)
+CairoMakie.contour!(ax_true, xs_plot, ys_plot, u_real_plot; levels=10, linewidth=2, color=:black)
+Colorbar(fig_true[1, 2]; limits=extrema(u_real_plot), label="True u")
+display(fig_true)
+
+@views function physics_loss_function(u_net::StatefulLuxLayer, p_net::StatefulLuxLayer, q_net::StatefulLuxLayer, xy_batch_normalized::AbstractArray)
+    # First derivatives of the auxiliary networks p ≈ ∂u/∂x and q ≈ ∂u/∂y.
+    dpdx_norm = Enzyme.gradient(Enzyme.Reverse, (x) -> sum(p_net(x)), xy_batch_normalized)[1][1:1, :]
+    dqdy_norm = Enzyme.gradient(Enzyme.Reverse, (x) -> sum(q_net(x)), xy_batch_normalized)[1][2:2, :]
+
+    xy_batch_actual = xy_batch_normalized .* (max_xy_pde .- min_xy_pde) .+ min_xy_pde
+    f_vals_actual = f_vec(xy_batch_actual)
+    physics_scale_factor = 1.0f0 / (global_max_output_val - global_min_output_val)
+    f_vals_scaled_for_physics_loss = f_vals_actual .* physics_scale_factor
+
+    # Residual of the first-order form: ∂p/∂x + ∂q/∂y = f (i.e. Δu = f).
+    pde_residual_component = dpdx_norm .+ dqdy_norm .- f_vals_scaled_for_physics_loss
+    pde_res_loss = mean(abs2, pde_residual_component)
+
+    ∂u_∂xy_norm = Enzyme.gradient(Enzyme.Reverse, (x) -> sum(u_net(x)), xy_batch_normalized)[1]
+    ∂u_∂x_norm = ∂u_∂xy_norm[1:1, :]
+    ∂u_∂y_norm = ∂u_∂xy_norm[2:2, :]
+
+    p_pred_norm = p_net(xy_batch_normalized)
+    q_pred_norm = q_net(xy_batch_normalized)
+
+    # Consistency terms tie the auxiliary networks to the gradient of u.
+    p_consistency_residual = p_pred_norm .- ∂u_∂x_norm
+    q_consistency_residual = q_pred_norm .- ∂u_∂y_norm
+    consistency_res_loss = mean(abs2, p_consistency_residual) + mean(abs2, q_consistency_residual)
+
+    total_physics_loss = pde_res_loss + consistency_res_loss
+
+    mean_abs_f_val = mean(abs.(f_vals_scaled_for_physics_loss))
+    mean_abs_laplacian_val = mean(abs.(dpdx_norm .+ dqdy_norm))
+
+    return total_physics_loss, mean_abs_f_val, mean_abs_laplacian_val
+end
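+
+# NOTE (illustrative): the mixed first-order formulation above (p ≈ ∂u/∂x, q ≈ ∂u/∂y)
+# trades second derivatives of one network for first derivatives of three, so each
+# term needs only a single Enzyme reverse pass. A quick optional shape check,
+# assuming a batch of 4 points:
+#     xy = rand(Float32, 2, 4)
+#     @assert size(f_vec(xy)) == (1, 4)   # must match the 1×N derivative slices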
+
+function mse_loss_function(u_net::StatefulLuxLayer, target_normalized::AbstractArray, xy_normalized::AbstractArray)
+    u_pred_normalized = u_net(xy_normalized)
+    return MSELoss()(u_pred_normalized, target_normalized)
+end
+
+function loss_function(model, ps, st, (xy_pde_normalized, target_data_normalized, xy_bc_normalized, target_bc_normalized))
+    u_net = StatefulLuxLayer{true}(model.u, ps.u, st.u)
+    p_net = StatefulLuxLayer{true}(model.p, ps.p, st.p)
+    q_net = StatefulLuxLayer{true}(model.q, ps.q, st.q)
+
+    physics_loss, mean_abs_f, mean_abs_dpdx_dqdy = physics_loss_function(u_net, p_net, q_net, xy_pde_normalized)
+
+    data_loss = mse_loss_function(u_net, target_data_normalized, xy_pde_normalized)
+    bc_loss = mse_loss_function(u_net, target_bc_normalized, xy_bc_normalized)
+
+    w_physics = 1.0f0
+    w_data = 1000.0f0
+    w_bc = 5000.0f0
+
+    total_loss = w_physics * physics_loss + w_data * data_loss + w_bc * bc_loss
+
+    updated_st_overall = merge(st, (u=u_net.st, p=p_net.st, q=q_net.st))
+
+    return (
+        total_loss,
+        updated_st_overall,
+        (; physics_loss, data_loss, bc_loss, mean_abs_f, mean_abs_dpdx_dqdy)
+    )
+end
+
+train_state = Lux.Training.TrainState(pinn_poisson, ps_poisson, st_poisson, Adam(0.001f0))
+
+lr = i -> i < 10000 ? 0.001f0 : (i < 30000 ? 0.0001f0 : 0.00001f0)
+
+bc_dataloader = MLUtils.DataLoader(
+    (xy_bc_points_normalized, target_bc_values_normalized); batchsize=128, shuffle=true, partial=false
+) |> reactant_device()
+pde_dataloader = MLUtils.DataLoader(
+    (xy_pde_pts_normalized, target_data_values_normalized); batchsize=128, shuffle=true, partial=false
+) |> reactant_device();
+
+total_loss_tracker, physics_loss_tracker, data_loss_tracker, bc_loss_tracker, mean_abs_f_tracker, mean_abs_dpdx_dqdy_tracker = ntuple(
+    _ -> OnlineStats.CircBuff(Float32, 32; rev=true), 6
+)
+
+iter = 1
+maxiters = 50000
+
+for ((xy_pde_batch, target_data_batch), (xy_bc_batch, target_bc_batch)) in
+    zip(Iterators.cycle(pde_dataloader), Iterators.cycle(bc_dataloader))
+    # Top-level loops have soft scope: declare the globals we reassign.
+    global iter, train_state
+
+    data_tuple = (xy_pde_batch, target_data_batch, xy_bc_batch, target_bc_batch)
+
+    Optimisers.adjust!(train_state, lr(iter))
+
+    # `single_train_step!` returns (grads, loss, stats, train_state).
+    _, loss, stats, train_state = Lux.Training.single_train_step!(
+        AutoEnzyme(),
+        loss_function,
+        data_tuple,
+        train_state;
+        return_gradients=Val(false),
+    )
+
+    fit!(total_loss_tracker, Float32(loss))
+    fit!(physics_loss_tracker, Float32(stats.physics_loss))
+    fit!(data_loss_tracker, Float32(stats.data_loss))
+    fit!(bc_loss_tracker, Float32(stats.bc_loss))
+    fit!(mean_abs_f_tracker, Float32(stats.mean_abs_f))
+    fit!(mean_abs_dpdx_dqdy_tracker, Float32(stats.mean_abs_dpdx_dqdy))
+
+    mean_loss = mean(OnlineStats.value(total_loss_tracker))
+    mean_physics_loss = mean(OnlineStats.value(physics_loss_tracker))
+    mean_data_loss = mean(OnlineStats.value(data_loss_tracker))
+    mean_bc_loss = mean(OnlineStats.value(bc_loss_tracker))
+    mean_mean_abs_f = mean(OnlineStats.value(mean_abs_f_tracker))
+    mean_mean_abs_dpdx_dqdy = mean(OnlineStats.value(mean_abs_dpdx_dqdy_tracker))
+
+    isnan(loss) && throw(ArgumentError("NaN Loss Detected"))
+
+    if iter % 1000 == 1 || iter == maxiters
+        @printf "Iteration: [%6d/%6d] \t Loss: %.9f (%.9f) \t Physics Loss: %.9f (%.9f) \t Data Loss: %.9f (%.9f) \t BC Loss: %.9f (%.9f) \t |f|: %.3f (|dp/dx+dq/dy|): %.3f\n" iter maxiters loss mean_loss stats.physics_loss mean_physics_loss stats.data_loss mean_data_loss stats.bc_loss mean_bc_loss mean_mean_abs_f mean_mean_abs_dpdx_dqdy
+    end
+
+    iter += 1
+    iter ≥ maxiters && break
+end
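+
+# NOTE (illustrative): evaluation below runs on the CPU; the trained parameters and
+# states are copied off the device once and wrapped in a StatefulLuxLayer, so the
+# network can be called on plain Arrays, e.g.
+#     u_at(x, y) = only(trained_u(reshape(Float32[x, y], 2, 1)))
+# (`u_at` is a hypothetical helper, not part of this example.)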
+
+cdev = cpu_device()
+trained_u = StatefulLuxLayer{true}(
+    pinn_poisson.u, cdev(train_state.parameters.u), cdev(train_state.states.u)
+)
+
+xs_plot_grid = collect(xs_plot)
+ys_plot_grid = collect(ys_plot)
+xy_plot_cpu = stack([[elem...] for elem in vec(collect(Iterators.product(xs_plot_grid, ys_plot_grid)))])
+
+min_xy_plot = minimum(xy_plot_cpu)
+max_xy_plot = maximum(xy_plot_cpu)
+xy_plot_cpu_normalized = (xy_plot_cpu .- min_xy_plot) ./ (max_xy_plot - min_xy_plot)
+
+u_pred_plot_normalized = trained_u(xy_plot_cpu_normalized)
+u_pred_plot_normalized = reshape(u_pred_plot_normalized, length(xs_plot), length(ys_plot))
+u_pred_plot = (u_pred_plot_normalized .* (global_max_output_val - global_min_output_val)) .+ global_min_output_val
+
+fig_trained = Figure()
+ax_trained = CairoMakie.Axis(fig_trained[1, 1]; xlabel="x", ylabel="y", title="Trained Solution")
+CairoMakie.heatmap!(ax_trained, xs_plot, ys_plot, u_pred_plot)
+CairoMakie.contour!(ax_trained, xs_plot, ys_plot, u_pred_plot; levels=10, linewidth=2, color=:black)
+CairoMakie.Colorbar(fig_trained[1, 2]; limits=extrema(u_pred_plot), label="Predicted u")
+display(fig_trained)
+
+u_real_plot_final = [analytical_solution_poisson(x, y) for x in xs_plot, y in ys_plot]
+abs_error_plot = abs.(u_pred_plot .- u_real_plot_final)
+
+fig_error = Figure()
+ax_error = CairoMakie.Axis(fig_error[1, 1]; xlabel="x", ylabel="y", title="Absolute Error")
+CairoMakie.heatmap!(ax_error, xs_plot, ys_plot, abs_error_plot, colorrange=(0, maximum(abs_error_plot)))
+CairoMakie.Colorbar(fig_error[1, 2]; limits=(0, maximum(abs_error_plot)), label="Absolute Error")
+display(fig_error)
+
+println("Maximum Absolute Error: ", maximum(abs_error_plot))
\ No newline at end of file
diff --git a/docs/src/examples/poisson_symbolic.jl b/docs/src/examples/poisson_symbolic.jl
new file mode 100644
index 000000000..f7dd22c6f
--- /dev/null
+++ b/docs/src/examples/poisson_symbolic.jl
@@ -0,0 +1,109 @@
+using ModelingToolkit, ModelingToolkitNeuralNets, Optimization, OptimizationOptimJL, Symbolics,
+    Lux, NNlib, StableRNGs, Random, Plots, LineSearches, ComponentArrays
+
+@variables x y
+Dx = Differential(x)
+Dy = Differential(y)
+Dxx = Dx^2
+Dyy = Dy^2
+
+# Symbolic Neural Network
+chain = Lux.Chain(
+    Lux.Dense(2, 16, σ),
+    Lux.Dense(16, 16, σ),
+    Lux.Dense(16, 1)
+)
+NN_expr, p_nn_sym = SymbolicNeuralNetwork(; chain=chain, n_input=2, n_output=1, rng=StableRNG(42))
+u_expr = NN_expr([x, y], p_nn_sym)[1]
+
+# PDE and Boundary Conditions
+# Δu = f with f = -sin(πx)sin(πy), so that u(x, y) = sin(πx)sin(πy)/(2π²)
+# matches `u_exact` below.
+f_rhs = -sin(π * x) * sin(π * y)
+pde_residual_expr = expand_derivatives(Dxx(u_expr) + Dyy(u_expr)) - f_rhs
+
+bc_x0_residual = substitute(u_expr, x => 0.0)
+bc_x1_residual = substitute(u_expr, x => 1.0)
+bc_y0_residual = substitute(u_expr, y => 0.0)
+bc_y1_residual = substitute(u_expr, y => 1.0)
+
+# Data Sampling
+num_col = 200
+num_bc = 50
+Random.seed!(123)
+
+collocation_pts_x = rand(num_col)
+collocation_pts_y = rand(num_col)
+
+bc_x0_y_data = rand(num_bc)
+bc_x1_y_data = rand(num_bc)
+bc_y0_x_data = rand(num_bc)
+bc_y1_x_data = rand(num_bc)
+
+@parameters col_x[1:num_col] col_y[1:num_col]
+@parameters bcx0_y_sym[1:num_bc] bcx1_y_sym[1:num_bc]
+@parameters bcy0_x_sym[1:num_bc] bcy1_x_sym[1:num_bc]
+
+# Loss
+pde_loss_expr = sum(Symbolics.fast_substitute(pde_residual_expr^2, Dict(x => col_x[i], y => col_y[i])) for i in 1:num_col)
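+
+# NOTE (illustrative): `fast_substitute` stamps the squared residual out once per
+# collocation point, so the loss is one large symbolic expression. With num_col = 200
+# and num_bc = 50 this stays tractable, but the expression grows linearly with the
+# number of sample points.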
+
+bc_loss_expr = sum(
+    Symbolics.fast_substitute(bc_x0_residual^2, Dict(y => bcx0_y_sym[i])) +
+    Symbolics.fast_substitute(bc_x1_residual^2, Dict(y => bcx1_y_sym[i])) +
+    Symbolics.fast_substitute(bc_y0_residual^2, Dict(x => bcy0_x_sym[i])) +
+    Symbolics.fast_substitute(bc_y1_residual^2, Dict(x => bcy1_x_sym[i]))
+    for i in 1:num_bc
+)
+
+total_loss = pde_loss_expr + bc_loss_expr
+
+# Optimization
+sym_data = vcat(col_x, col_y, bcx0_y_sym, bcx1_y_sym, bcy0_x_sym, bcy1_x_sym)
+u0 = randn(Float64, length(p_nn_sym))
+rng = Random.GLOBAL_RNG
+# `ca` fixes the ComponentArray axes used to reshape the flat parameter vector.
+init_params = Lux.initialparameters(rng, chain)
+ca = ComponentArray(init_params)
+
+flat_data = vcat(
+    collocation_pts_x,
+    collocation_pts_y,
+    bc_x0_y_data,
+    bc_x1_y_data,
+    bc_y0_x_data,
+    bc_y1_x_data
+)
+
+@assert length(sym_data) == length(flat_data)
+
+@named system = OptimizationSystem(total_loss, p_nn_sym, [sym_data; NN_expr])
+system = complete(system)
+
+prob = OptimizationProblem(system, [p_nn_sym => u0], [sym_data .=> flat_data; NN_expr => ModelingToolkitNeuralNets.StatelessApplyWrapper(chain, typeof(ca))]; grad=true)
+result = solve(prob, LBFGS(linesearch=LineSearches.BackTracking()); maxiters=1000)
+
+# Evaluation
+function u_pred(xval, yval, chain, optimized_params)
+    ps = convert(typeof(ca), optimized_params)
+    return ModelingToolkitNeuralNets.stateless_apply(chain, [xval, yval], ps)[1]
+end
+
+u_exact(x, y) = (sin(pi * x) * sin(pi * y)) / (2pi^2)
+
+xs = range(0, 1; length=100)
+ys = range(0, 1; length=100)
+
+u_approx = [u_pred(x, y, chain, result.u) for y in ys, x in xs]
+u_truth = [u_exact(x, y) for y in ys, x in xs]
+u_error = abs.(u_approx .- u_truth)
+
+# Plots
+layout = @layout [a b c]
+plt1 = heatmap(xs, ys, u_approx; title="Predicted u(x,y)", xlabel="x", ylabel="y", colorbar_title="u")
+plt2 = heatmap(xs, ys, u_truth; title="Exact u(x,y)", xlabel="x", ylabel="y", colorbar_title="u")
+plt3 = heatmap(xs, ys, u_error; title="|u_pred - u_exact|", xlabel="x", ylabel="y", colorbar_title="abs error")
+
+plot(plt1, plt2, plt3; layout=layout, size=(1100, 300))
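+
+# NOTE (illustrative): a scalar summary of the fit, assuming `u_error` and `u_truth`
+# from above:
+#     println("max |error| = ", maximum(u_error),
+#         ", relative L2 = ", sqrt(sum(abs2, u_error) / sum(abs2, u_truth)))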