######################################################
# training loop for variational objectives
######################################################
"""
    pm_next!(pm, stats::NamedTuple)

Advance the progress meter `pm` by one step, showing every entry of
`stats` as a `(name, value)` pair via the `showvalues` keyword.
"""
function pm_next!(pm, stats::NamedTuple)
    # `map(tuple, keys, values)` builds the (name, value) tuples directly,
    # avoiding the per-pair splat of `[tuple(s...) for s in pairs(stats)]`.
    return ProgressMeter.next!(pm; showvalues=map(tuple, keys(stats), values(stats)))
end
7
7
8
# Wrap each extra positional argument as a `DifferentiationInterface.Constant`
# context, marking it as non-differentiated data for the AD backend.
# `map` over the (typically tuple) `args` is type-stable, unlike collecting
# into a `Vector` with `.([args...])`.
_wrap_in_DI_context(args) = map(DifferentiationInterface.Constant, args)
9
9
10
10
"""
    _prepare_gradient(loss, adbackend, θ, args...)

Build a `DifferentiationInterface` gradient preparation object for `loss`
at parameters `θ` using backend `adbackend`. Any additional `args` are
passed as non-differentiated `Constant` contexts.
"""
function _prepare_gradient(loss, adbackend, θ, args...)
    if isempty(args)
        return DifferentiationInterface.prepare_gradient(loss, adbackend, θ)
    end
    return DifferentiationInterface.prepare_gradient(
        loss, adbackend, θ, map(DifferentiationInterface.Constant, args)...
    )
end
16
16
17
17
"""
    _value_and_gradient(loss, prep, adbackend, θ, args...)

Evaluate `loss` at `θ` and compute its gradient in one call, reusing the
preparation object `prep` from [`_prepare_gradient`](@ref). Any additional
`args` are passed as non-differentiated `Constant` contexts and must match
those used during preparation.
"""
function _value_and_gradient(loss, prep, adbackend, θ, args...)
    if isempty(args)
        return DifferentiationInterface.value_and_gradient(loss, prep, adbackend, θ)
    end
    return DifferentiationInterface.value_and_gradient(
        loss, prep, adbackend, θ, map(DifferentiationInterface.Constant, args)...
    )
end
23
23
24
24
@@ -42,7 +42,6 @@ Iteratively updating the parameters `θ` of the normalizing flow `re(θ)` by cal
42
42
- `re`: reconstruction function that maps the flattened parameters to the normalizing flow
43
43
- `args...`: additional arguments for `loss` (will be set as DifferentiationInterface.Constant)
44
44
45
-
46
45
# Keyword Arguments
47
46
- `max_iters::Int=10000`: maximum number of iterations
48
47
- `optimiser::Optimisers.AbstractRule=Optimisers.ADAM()`: optimiser to compute the steps
@@ -102,9 +101,9 @@ function optimize(
102
101
stat = (iteration= i, loss= ls, gradient_norm= norm (g))
103
102
104
103
# callback
105
- if ! isnothing ( callback)
104
+ if callback != = nothing
106
105
new_stat = callback (i, opt_stats, reconstruct, θ)
107
- stat = ! isnothing ( new_stat) ? merge (stat, new_stat) : stat
106
+ stat = new_stat != = nothing ? merge (stat, new_stat) : stat
108
107
end
109
108
push! (opt_stats, stat)
110
109
0 commit comments