Commit 0d6302b

sunxd3 and devmotion authored
Apply suggestions from code review
Co-authored-by: David Widmann <[email protected]>
1 parent 906d788 commit 0d6302b

2 files changed, +9 −9 lines changed


src/NormalizingFlows.jl

Lines changed: 3 additions & 2 deletions
@@ -11,7 +11,8 @@ using DocStringExtensions

 export train_flow, elbo, loglikelihood

-""" train_flow([rng::AbstractRNG, ]vo, flow, args...; kwargs...)
+"""
+    train_flow([rng::AbstractRNG, ]vo, flow, args...; kwargs...)

 Train the given normalizing flow `flow` by calling `optimize`.

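This hunk adopts the standard Julia docstring layout: the signature moves onto its own line, indented by four spaces below the opening `"""`, so the documentation system renders it as a code block. A minimal illustration of that layout (the function `add_scaled` is made up for this sketch and is not part of the package):

    """
        add_scaled(x, y; scale=1.0)

    Return `x + scale * y`.
    """
    add_scaled(x, y; scale=1.0) = x + scale * y
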
@@ -61,7 +62,7 @@ function train_flow(
         ADbackend,
         loss,
         θ_flat,
-        re,
+        re,
         (rng, args...)...;
         max_iters=max_iters,
         optimiser=optimiser,

src/optimize.jl

Lines changed: 6 additions & 7 deletions
@@ -2,23 +2,23 @@
 # training loop for variational objectives
 #######################################################
 function pm_next!(pm, stats::NamedTuple)
-    return ProgressMeter.next!(pm; showvalues=[tuple(s...) for s in pairs(stats)])
+    return ProgressMeter.next!(pm; showvalues=map(tuple, keys(stats), values(stats)))
 end

-_wrap_in_DI_context(args) = DifferentiationInterface.Constant.([args...])
+_wrap_in_DI_context(args) = map(DifferentiationInterface.Constant, args)

 function _prepare_gradient(loss, adbackend, θ, args...)
     if isempty(args)
         return DifferentiationInterface.prepare_gradient(loss, adbackend, θ)
     end
-    return DifferentiationInterface.prepare_gradient(loss, adbackend, θ, _wrap_in_DI_context(args)...)
+    return DifferentiationInterface.prepare_gradient(loss, adbackend, θ, map(DifferentiationInterface.Constant, args)...)
 end

 function _value_and_gradient(loss, prep, adbackend, θ, args...)
     if isempty(args)
         return DifferentiationInterface.value_and_gradient(loss, prep, adbackend, θ)
     end
-    return DifferentiationInterface.value_and_gradient(loss, prep, adbackend, θ, _wrap_in_DI_context(args)...)
+    return DifferentiationInterface.value_and_gradient(loss, prep, adbackend, θ, map(DifferentiationInterface.Constant, args)...)
 end

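For `pm_next!`, `map(tuple, keys(stats), values(stats))` turns a `NamedTuple` such as `(iteration=1, loss=0.5)` into the tuple of `(name, value)` pairs passed to `ProgressMeter.next!` via `showvalues`, without building an intermediate array. For the gradient helpers, mapping `DifferentiationInterface.Constant` over the `args` tuple keeps a tuple (rather than the `Vector` produced by `Constant.([args...])`), which is then splatted into `prepare_gradient` and `value_and_gradient`. A minimal, self-contained sketch of that call pattern, assuming ForwardDiff as the AD backend and a made-up `loss` (neither is taken from this commit):

    # Sketch only (not from this repository): how a tuple of Constant contexts
    # is built and splatted into DifferentiationInterface calls.
    import ForwardDiff
    using ADTypes: AutoForwardDiff
    import DifferentiationInterface as DI

    loss(θ, data) = sum(abs2, θ .- data)      # `data` is held fixed; only θ is differentiated

    θ = [0.1, -0.2, 0.3]
    args = ([1.0, 2.0, 3.0],)                 # extra positional arguments for `loss`
    contexts = map(DI.Constant, args)         # a tuple of Constant, not a Vector

    backend = AutoForwardDiff()
    prep = DI.prepare_gradient(loss, backend, θ, contexts...)
    ls, g = DI.value_and_gradient(loss, prep, backend, θ, contexts...)
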

@@ -42,7 +42,6 @@ Iteratively updating the parameters `θ` of the normalizing flow `re(θ)` by cal
 - `re`: reconstruction function that maps the flattened parameters to the normalizing flow
 - `args...`: additional arguments for `loss` (will be set as DifferentiationInterface.Constant)

-
 # Keyword Arguments
 - `max_iters::Int=10000`: maximum number of iterations
 - `optimiser::Optimisers.AbstractRule=Optimisers.ADAM()`: optimiser to compute the steps
@@ -102,9 +101,9 @@ function optimize(
         stat = (iteration=i, loss=ls, gradient_norm=norm(g))

         # callback
-        if !isnothing(callback)
+        if callback !== nothing
             new_stat = callback(i, opt_stats, reconstruct, θ)
-            stat = !isnothing(new_stat) ? merge(stat, new_stat) : stat
+            stat = new_stat !== nothing ? merge(stat, new_stat) : stat
         end
         push!(opt_stats, stat)
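This last hunk only swaps `!isnothing(x)` for `x !== nothing`; the callback contract is unchanged: the callback receives `(i, opt_stats, reconstruct, θ)`, any `NamedTuple` it returns is merged into that iteration's `stat` via `merge(stat, new_stat)`, and returning `nothing` leaves the stats untouched. A hypothetical callback following that contract (the name `log_theta_norm` is an assumption, not part of this commit):

    using LinearAlgebra: norm

    # Hypothetical callback: every 100 iterations, add the parameter norm to the
    # logged stats; otherwise return nothing so the stats stay as they are.
    function log_theta_norm(i, opt_stats, reconstruct, θ)
        return i % 100 == 0 ? (θ_norm=norm(θ),) : nothing
    end
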
