make realnvp and nsf layers as part of the pkg #53
|
After the recent Mooncake update, with the following setup

```julia
using Random, Distributions, LinearAlgebra
using Flux, Functors, Optimisers
using Bijectors
using Bijectors: partition, combine, PartitionMask
using Mooncake, Enzyme, ADTypes
import DifferentiationInterface as DI

# just define a MLP
function mlp3(
    input_dim::Int,
    hidden_dims::Int,
    output_dim::Int;
    activation=Flux.leakyrelu,
    paramtype::Type{T} = Float64
) where {T<:AbstractFloat}
    m = Chain(
        Flux.Dense(input_dim, hidden_dims, activation),
        Flux.Dense(hidden_dims, hidden_dims, activation),
        Flux.Dense(hidden_dims, output_dim),
    )
    return Flux._paramtype(paramtype, m)
end

inputdim = 4
mask_idx = 1:2:inputdim

# create a masking layer
mask = PartitionMask(inputdim, mask_idx)
cdim = length(mask_idx)

x = randn(inputdim)
t_net = mlp3(cdim, 16, cdim; paramtype = Float64)
ps, st = Optimisers.destructure(t_net)
```

the following code runs perfectly

```julia
function loss(ps, st, x, mask)
    t_net = st(ps)
    x₁, x₂, x₃ = partition(mask, x)
    y₁ = x₁ .+ t_net(x₂)
    y = combine(mask, y₁, x₂, x₃)
    # println("y = ", y)
    return sum(abs2, y)
end

loss(ps, st, x, mask) # return 3.0167880799441793

val, grad = DI.value_and_gradient(
    loss,
    ADTypes.AutoMooncake(; config = Mooncake.Config()),
    ps, DI.Cache(st), DI.Constant(x), DI.Constant(mask)
)
```

but autograd fails if I wrap the mask and the network in a struct:

```julia
struct ACL
    mask::Bijectors.PartitionMask
    t::Flux.Chain
end
@functor ACL (t,)

acl = ACL(mask, t_net)
psacl, stacl = Optimisers.destructure(acl)

function loss_acl(ps, st, x)
    acl = st(ps)
    t_net = acl.t
    mask = acl.mask
    x₁, x₂, x₃ = partition(mask, x)
    y₁ = x₁ .+ t_net(x₂)
    y = combine(mask, y₁, x₂, x₃)
    return sum(abs2, y)
end

loss_acl(psacl, stacl, x) # return 3.0167880799441793

val, grad = DI.value_and_gradient(
    loss_acl,
    ADTypes.AutoMooncake(; config = Mooncake.Config()),
    psacl, DI.Cache(stacl), DI.Constant(x)
)
```

with error message
Running the same with Enzyme

```julia
val, grad = DI.value_and_gradient(
    loss_acl,
    ADTypes.AutoEnzyme(;
        mode=Enzyme.set_runtime_activity(Enzyme.Reverse),
        function_annotation=Enzyme.Const,
    ),
    psacl, DI.Cache(stacl), DI.Constant(x)
)
```

with output

Any thoughts on this @yebai @willtebbutt?
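As a sanity check on the loss itself (not part of the original report), one can push a ForwardDiff gradient through the same destructure/restructure closure to confirm the wrapped loss is differentiable independently of the reverse-mode backends; a minimal sketch, assuming ForwardDiff is available in the environment:

```julia
using ForwardDiff

# Differentiate only w.r.t. the flat parameter vector, closing over the
# restructure object `stacl` and the input `x` defined above.
grad_fd = ForwardDiff.gradient(p -> loss_acl(p, stacl, x), psacl)
```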
|
Ah, looks like it only has an issue when only part of the fields in the structure is annotated by `@functor`:

```julia
struct Holder
    t::Flux.Chain
end
@functor Holder

psh, sth = Optimisers.destructure(Holder(t_net))

function loss2(ps, st, x)
    holder = st(ps)
    t_net = holder.t
    y = x .+ t_net(x)
    return sum(abs2, y)
end

loss2(psh, sth, x) # return 7.408352005690478

val, grad = DI.value_and_gradient(
    loss2,
    ADTypes.AutoMooncake(; config = Mooncake.Config()),
    psh, DI.Cache(sth), DI.Constant(x)
)
```

with outputs
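To narrow down whether the failure comes from the DifferentiationInterface wrapper or from Mooncake itself, one could also differentiate the failing closure with Mooncake directly. A rough sketch, assuming Mooncake's documented `build_rrule`/`value_and_gradient!!` interface (exact names may vary across Mooncake versions):

```julia
using Mooncake

# Close over the restructure object and the input so that only the flat
# parameter vector is an explicit argument.
ls(p) = loss_acl(p, stacl, x)

rule = Mooncake.build_rrule(ls, psacl)
val, grads = Mooncake.value_and_gradient!!(rule, ls, psacl)
# grads[1] is the tangent of `ls` itself; grads[2] is the gradient w.r.t. `psacl`.
```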
|
@zuhengxu, can you help bisect which Mooncake version / Julia version introduced this bug? |
Good point! I'll look at this today. |
|
It appears that the remaining issues with Mooncake are minor, likely due to the lack of a specific rule. @sunxd3, can you help check whether a new rule is needed?
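In case it does come down to a missing rule, the general shape of adding one looks roughly like the sketch below. `foo` is a hypothetical stand-in, not the call that is actually failing here, and the names follow Mooncake's documented rule-writing interface (check the current docs before relying on them):

```julia
import Mooncake
using Mooncake: CoDual, MinimalCtx, NoRData, primal, zero_fcodual

# Hypothetical function that Mooncake should treat as a primitive.
foo(x::Vector{Float64}) = sum(abs2, x)

# Tell Mooncake not to recurse into `foo` for this signature.
Mooncake.@is_primitive MinimalCtx Tuple{typeof(foo),Vector{Float64}}

# Hand-written reverse rule: y = foo(x), so dx += 2x * dy.
function Mooncake.rrule!!(::CoDual{typeof(foo)}, x::CoDual{Vector{Float64}})
    y = zero_fcodual(foo(primal(x)))
    function foo_pullback(dy)
        x.dx .+= 2 .* primal(x) .* dy  # accumulate into the fdata of x
        return NoRData(), NoRData()    # rdata for `foo` and for `x`
    end
    return y, foo_pullback
end
```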
|
I'll look into it 👍 |
Red-Portal left a comment
Hi! I have some minor suggestions
|
NormalizingFlows.jl documentation for PR #53 is available at: |
|
Thank you again @yebai @sunxd3 @Red-Portal for the help and comments throughout this PR! Let me know if it looks good to you and I'll merge it afterwards.
|
Sorry for the delay. Reviewing a paper for JMLR has been taking up all my bandwidth. I'll take a deeper look tomorrow.
Red-Portal left a comment
I only have minor suggestions. Feel free to take a look and apply them if you agree. Otherwise, looks good to me.
|
Sorry for missing the tag, allow me to take a look later today or tomorrow.
sunxd3 left a comment
couple of tiny things, very happy to do another round of review
|
Thank you @sunxd3 @Red-Portal for the review! I made the corresponding updates; let me know if the current version looks good to you!
sunxd3 left a comment
another couple of tiny things, nothing major beyond these
|
pretty much good to go from my end, but let's wait for Kyurae to take a look? |
Red-Portal left a comment
I only have a few minor comments that should be quick to handle!
|
@Red-Portal @sunxd3 Let me know if I can hit the big green button! Thanks for the quick feedback. |
|
Alright looks good to me now! |
As discussed in #36 (see #36 (comment)), I'm moving the `AffineCoupling` and `NeuralSplineLayer` from the example to `src/` so they can be called directly. Specifically, this PR:

- moves `AffineCoupling` and `NeuralSplineLayer` into `src`
- adds a `realnvp` and a `neuralsplineflow` constructor (a rough sketch of an affine coupling layer follows below). For the `realnvp`, follow the default architecture as mentioned in Advances in Black-Box VI: Normalizing Flows, Importance Weighting, and Optimization.
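For reference, here is a minimal sketch of what an affine coupling layer in the RealNVP style can look like, modelled on the `ACL` snippet earlier in this thread; it is an illustration under assumptions (field names, the callable interface, and the reuse of the `mlp3` helper from above), not the final API added in `src/`:

```julia
using Flux, Functors
using Bijectors: partition, combine, PartitionMask

# Affine coupling: transform one partition of x conditioned on the other.
struct AffineCoupling
    mask::PartitionMask
    s::Flux.Chain   # scale network
    t::Flux.Chain   # shift network
end
@functor AffineCoupling (s, t)   # only the networks carry trainable parameters

function (af::AffineCoupling)(x::AbstractVector)
    x₁, x₂, x₃ = partition(af.mask, x)
    y₁ = x₁ .* exp.(af.s(x₂)) .+ af.t(x₂)   # scale-and-shift the first partition
    return combine(af.mask, y₁, x₂, x₃)
end

# Hypothetical usage, reusing `mlp3` from the discussion above.
dim, mask_idx = 4, 1:2:4
cdim = length(mask_idx)
layer = AffineCoupling(PartitionMask(dim, mask_idx), mlp3(cdim, 16, cdim), mlp3(cdim, 16, cdim))
layer(randn(dim))
```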