Inverse Dirichlet Adaptive Loss #504
@@ -298,6 +298,39 @@ SciMLBase.@add_kwonly function GradientScaleAdaptiveLoss(reweight_every; weight_
        pde_loss_weights=pde_loss_weights, bc_loss_weights=bc_loss_weights, additional_loss_weights=additional_loss_weights)
end
| """ | ||
| A way of adaptively reweighting the components of the loss function in the total sum such that BC_i loss weights are based on the gradient variance. In particular, the weights are chosen so that the | ||
| variances over the components of the back-propagated weighted gradients are equal across all objectives. | ||
|
|
||
| * `reweight_every`: how often to reweight the BC loss functions, measured in iterations. reweighting is somewhat expensive since it involves evaluating the gradient of each component loss function, | ||
| * `weight_change_inertia`: a real number that represents the inertia of the exponential moving average of the BC weight changes, | ||
| * `pde_loss_weights`: either a scalar (which will be broadcast) or vector the size of the number of PDE equations, which describes the weight the respective PDE loss has in the full loss sum, | ||
| * `bc_loss_weights`: either a scalar (which will be broadcast) or vector the size of the number of BC equations, which describes the initial weight the respective BC loss has in the full loss sum, | ||
| * `additional_loss_weights`: a scalar which describes the weight the additional loss function has in the full loss sum, this is currently not adaptive and will be constant with this adaptive loss, | ||
|
|
||
| from paper | ||
| Inverse Dirichlet weighting enables reliable training of physics informed neural networks | ||
| Suryanarayana Maddu, Dominik Sturm, Christian L Müller, and Ivo F Sbalzarini | ||
| https://iopscience.iop.org/article/10.1088/2632-2153/ac3712/pdf | ||
| with code reference | ||
| https://github.com/mosaic-group/inverse-dirichlet-pinn | ||
| """ | ||
|
|
||
mutable struct InverseDirichletAdaptiveLoss{T <: Real} <: AbstractAdaptiveLoss
    reweight_every::Int64
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Rename this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I chose There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
    weight_change_inertia::T
    pde_loss_weights::Vector{T}
    bc_loss_weights::Vector{T}
    additional_loss_weights::Vector{T}
    SciMLBase.@add_kwonly function InverseDirichletAdaptiveLoss{T}(reweight_every; weight_change_inertia=0.9, pde_loss_weights=1, bc_loss_weights=1, additional_loss_weights=1) where T <: Real
        new(convert(Int64, reweight_every), convert(T, weight_change_inertia), vectorify(pde_loss_weights, T), vectorify(bc_loss_weights, T), vectorify(additional_loss_weights, T))
    end
end
# default to Float64
SciMLBase.@add_kwonly function InverseDirichletAdaptiveLoss(reweight_every; weight_change_inertia=0.9, pde_loss_weights=1, bc_loss_weights=1, additional_loss_weights=1)
    InverseDirichletAdaptiveLoss{Float64}(reweight_every; weight_change_inertia=weight_change_inertia,
        pde_loss_weights=pde_loss_weights, bc_loss_weights=bc_loss_weights, additional_loss_weights=additional_loss_weights)
end
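As a quick orientation for reviewers, here is a minimal usage sketch (not part of this diff). It assumes the same `adaptive_loss` keyword of `PhysicsInformedNN` that the existing adaptive losses use; the chain definition, training strategy, and exact constructor call are placeholders, not something this PR prescribes.

```julia
# Minimal sketch, assuming the existing NeuralPDE adaptive-loss API (not part of this PR)
using NeuralPDE, Flux

# reweight the BC losses every 100 iterations, keeping 90% of the previous weights (EMA inertia)
adaptive_loss = InverseDirichletAdaptiveLoss(100; weight_change_inertia = 0.9,
                                             pde_loss_weights = 1.0, bc_loss_weights = 1.0)

chain = Flux.Chain(Dense(2, 16, tanh), Dense(16, 16, tanh), Dense(16, 1))
discretization = PhysicsInformedNN(chain, GridTraining(0.1);
                                   adaptive_loss = adaptive_loss)
# `discretize(pde_system, discretization)` then yields an optimization problem whose
# BC loss weights are updated by the inverse Dirichlet rule every `reweight_every` iterations.
```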
| """ | ||
| A way of adaptively reweighting the components of the loss function in the total sum such that the loss weights are maximized by an internal optimiser, which leads to a behavior where loss functions that have not been satisfied get a greater weight, | ||
|
|
@@ -1406,6 +1439,25 @@ function discretize_inner_functions(pde_system::PDESystem, discretization::Physi
            end
            nothing
        end
    elseif adaloss isa InverseDirichletAdaptiveLoss
        weight_change_inertia = discretization.adaptive_loss.weight_change_inertia
        function run_loss_inverse_dirichlet_adaptive_loss(θ)
            if iteration[1] % adaloss.reweight_every == 0
                # standard deviation of the back-propagated gradient of each PDE loss component
                pde_grads_std_all = [std(Zygote.gradient(pde_loss_function, θ)[1]) for pde_loss_function in pde_loss_functions]
                pde_grads_std_max = maximum(pde_grads_std_all)
                # standard deviation of the back-propagated gradient of each BC loss component
                bc_grads_std = [std(Zygote.gradient(bc_loss_function, θ)[1]) for bc_loss_function in bc_loss_functions]

                nonzero_divisor_eps = adaloss_T isa Float64 ? Float64(1e-11) : convert(adaloss_T, 1e-7)

                # inverse Dirichlet weighting: each BC weight is the ratio of the largest PDE gradient
                # standard deviation to that BC's gradient standard deviation, smoothed by an
                # exponential moving average with inertia `weight_change_inertia`
                bc_loss_weights_proposed = pde_grads_std_max ./ (bc_grads_std .+ nonzero_divisor_eps)
                adaloss.bc_loss_weights .= weight_change_inertia .* adaloss.bc_loss_weights .+ (1 .- weight_change_inertia) .* bc_loss_weights_proposed

                logscalar(logger, pde_grads_std_max, "adaptive_loss/pde_grad_std_max", iteration[1])
                logvector(logger, pde_grads_std_all, "adaptive_loss/pde_grad_std_all", iteration[1])
                logvector(logger, bc_grads_std, "adaptive_loss/bc_grad_std", iteration[1])
                logvector(logger, adaloss.bc_loss_weights, "adaptive_loss/bc_loss_weights", iteration[1])
            end
            nothing
        end
    elseif adaloss isa MiniMaxAdaptiveLoss
        pde_max_optimiser = adaloss.pde_max_optimiser
        bc_max_optimiser = adaloss.bc_max_optimiser
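For readers of the diff: the update above implements the inverse Dirichlet weighting rule of Maddu et al., smoothed by the same exponential moving average scheme used by `GradientScaleAdaptiveLoss`. A sketch of the rule in the notation of the code (this summary is mine, not part of the diff), where σ(·) is the standard deviation over the components of the back-propagated gradient with respect to the parameters θ, and α is `weight_change_inertia`:

```latex
% proposed BC weights at a reweighting step
\hat{\lambda}_{\mathrm{bc},j} =
    \frac{\max_i \, \sigma\!\big(\nabla_\theta L_{\mathrm{pde},i}\big)}
         {\sigma\!\big(\nabla_\theta L_{\mathrm{bc},j}\big) + \varepsilon}

% exponential moving average of the BC weights
\lambda_{\mathrm{bc},j} \leftarrow
    \alpha\,\lambda_{\mathrm{bc},j} + (1 - \alpha)\,\hat{\lambda}_{\mathrm{bc},j}
```

Note that only the BC weights are rescaled here; the PDE loss weights are left at their initial values, so the BC gradient magnitudes are pulled up toward the largest PDE gradient standard deviation, which is the mechanism the paper uses to avoid vanishing BC gradients.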