
Conversation

@zuhengxu (Member) commented Jul 5, 2025

As discussed in #36 (see #36 (comment)), I'm moving the AffineCoupling and NeuralSplineLayer from the example to src/ so it can be called.

@zuhengxu (Member, Author) commented Jul 23, 2025

After the recent Mooncake update (Mooncake v0.4.140, DI v0.7.3, Julia v1.11.6), all RealNVP and neural spline flow tests/demos failed. After spending some time debugging, it seems the failure is due to Mooncake not traversing nested structures correctly. Here is an example:

using Random, Distributions, LinearAlgebra
using Flux, Optimisers
using Functors: @functor
using Bijectors
using Bijectors: partition, combine, PartitionMask
using Mooncake, Enzyme, ADTypes
import DifferentiationInterface as DI

# define a simple 3-layer MLP
function mlp3(
    input_dim::Int, 
    hidden_dims::Int, 
    output_dim::Int; 
    activation=Flux.leakyrelu,
    paramtype::Type{T} = Float64
) where {T<:AbstractFloat}
    m = Chain(
        Flux.Dense(input_dim, hidden_dims, activation),
        Flux.Dense(hidden_dims, hidden_dims, activation),
        Flux.Dense(hidden_dims, output_dim),
    )
    return Flux._paramtype(paramtype, m)
end


inputdim = 4
mask_idx = 1:2:inputdim
# create a masking layer
mask = PartitionMask(inputdim, mask_idx)
cdim = length(mask_idx)

x = randn(inputdim)

t_net = mlp3(cdim, 16, cdim; paramtype = Float64)
ps, st = Optimisers.destructure(t_net)

The following code runs perfectly:

function loss(ps, st, x, mask)
    t_net = st(ps)
    x₁, x₂, x₃ = partition(mask, x)
    y₁ = x₁ .+ t_net(x₂)
    y = combine(mask, y₁, x₂, x₃)
    # println("y = ", y)
    return sum(abs2, y)
end


loss(ps, st, x, mask)  # returns 3.0167880799441793

val, grad = DI.value_and_gradient(
    loss, 
    ADTypes.AutoMooncake(; config = Mooncake.Config()), 
    ps, DI.Cache(st), DI.Constant(x), DI.Constant(mask)
)

But autograd fails if I wrap t_net and mask into a struct:

struct ACL
    mask::Bijectors.PartitionMask
    t::Flux.Chain
end
@functor ACL (t, )

acl = ACL(mask, t_net)
psacl, stacl = Optimisers.destructure(acl)

function loss_acl(ps, st, x)
    acl = st(ps)
    t_net = acl.t
    mask = acl.mask
    x₁, x₂, x₃ = partition(mask, x)
    y₁ = x₁ .+ t_net(x₂)
    y = combine(mask, y₁, x₂, x₃)
    return sum(abs2, y)
end
loss_acl(psacl, stacl, x) # returns 3.0167880799441793

val, grad = DI.value_and_gradient(
    loss_acl, 
    ADTypes.AutoMooncake(; config = Mooncake.Config()), 
    psacl, DI.Cache(stacl), DI.Constant(x)
)

with error message:

Unreachable reached at 0x7597f9780297

[1530713] signal 4 (2): Illegal instruction
in expression starting at REPL[47]:1
RRuleZeroWrapper at /home/zuhdav/.julia/packages/Mooncake/qKNoG/src/interpreter/s2s_reverse_mode_ad.jl:302 [inlined]
|> at ./operators.jl:926 [inlined]
opaque closure at ./<missing>:0
DerivedRule at /home/zuhdav/.julia/packages/Mooncake/qKNoG/src/interpreter/s2s_reverse_mode_ad.jl:966
unknown function (ip: 0x7597f9780f51)
DynamicDerivedRule at /home/zuhdav/.julia/packages/Mooncake/qKNoG/src/interpreter/s2s_reverse_mode_ad.jl:1739
RRuleZeroWrapper at /home/zuhdav/.julia/packages/Mooncake/qKNoG/src/interpreter/s2s_reverse_mode_ad.jl:302
unknown function (ip: 0x7597f977b8ba)
_Trainable_biwalk at /home/zuhdav/.julia/packages/Optimisers/W5seC/src/destructure.jl:108 [inlined]
opaque closure at ./<missing>:0
DerivedRule at /home/zuhdav/.julia/packages/Mooncake/qKNoG/src/interpreter/s2s_reverse_mode_ad.jl:966 [inlined]
_build_rule! at /home/zuhdav/.julia/packages/Mooncake/qKNoG/src/interpreter/s2s_reverse_mode_ad.jl:1827
LazyDerivedRule at /home/zuhdav/.julia/packages/Mooncake/qKNoG/src/interpreter/s2s_reverse_mode_ad.jl:1822 [inlined]
RRuleZeroWrapper at /home/zuhdav/.julia/packages/Mooncake/qKNoG/src/interpreter/s2s_reverse_mode_ad.jl:302 [inlined]
loss_acl at ./REPL[43]:2 [inlined]
opaque closure at ./<missing>:0
DerivedRule at /home/zuhdav/.julia/packages/Mooncake/qKNoG/src/interpreter/s2s_reverse_mode_ad.jl:966
unknown function (ip: 0x7597fa9be0df)
jl_apply at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/src/julia.h:2157 [inlined]
do_apply at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/src/builtins.c:831
#prepare_gradient_cache#781 at /home/zuhdav/.julia/packages/Mooncake/qKNoG/src/interface.jl:509
prepare_gradient_cache at /home/zuhdav/.julia/packages/Mooncake/qKNoG/src/interface.jl:506 [inlined]
prepare_gradient_nokwarg at /home/zuhdav/.julia/packages/DifferentiationInterface/a7NWj/ext/DifferentiationInterfaceMooncakeExt/onearg.jl:114
value_and_gradient at /home/zuhdav/.julia/packages/DifferentiationInterface/a7NWj/src/first_order/gradient.jl:36
unknown function (ip: 0x7597fa9ade93)
jl_apply at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/src/julia.h:2157 [inlined]
do_call at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/src/interpreter.c:126
eval_value at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/src/interpreter.c:223
eval_stmt_value at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/src/interpreter.c:174 [inlined]
eval_body at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/src/interpreter.c:666
jl_interpret_toplevel_thunk at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/src/interpreter.c:824
jl_toplevel_eval_flex at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/src/toplevel.c:943
jl_toplevel_eval_flex at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/src/toplevel.c:886
jl_toplevel_eval_flex at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/src/toplevel.c:886
jl_toplevel_eval_flex at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/src/toplevel.c:886
ijl_toplevel_eval_in at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/src/toplevel.c:994
eval at ./boot.jl:430 [inlined]
eval_user_input at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:261
repl_backend_loop at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:368
#start_repl_backend#59 at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:343
start_repl_backend at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:340
#run_repl#76 at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:500
run_repl at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/usr/share/julia/stdlib/v1.11/REPL/src/REPL.jl:486
jfptr_run_repl_10230.1 at /home/zuhdav/.julia/juliaup/julia-1.11.6+0.x64.linux.gnu/share/julia/compiled/v1.11/REPL/u0gqU_QBeOa.so (unknown line)
#1150 at ./client.jl:446
jfptr_YY.1150_14947.1 at /home/zuhdav/.julia/juliaup/julia-1.11.6+0.x64.linux.gnu/share/julia/compiled/v1.11/REPL/u0gqU_QBeOa.so (unknown line)
jl_apply at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/src/julia.h:2157 [inlined]
jl_f__call_latest at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/src/builtins.c:875
#invokelatest#2 at ./essentials.jl:1055 [inlined]
invokelatest at ./essentials.jl:1052 [inlined]
run_main_repl at ./client.jl:430
repl_main at ./client.jl:567 [inlined]
_start at ./client.jl:541
jfptr__start_73597.1 at /home/zuhdav/.julia/juliaup/julia-1.11.6+0.x64.linux.gnu/lib/julia/sys.so (unknown line)
jl_apply at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/src/julia.h:2157 [inlined]
true_main at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/src/jlapi.c:900
jl_repl_entrypoint at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/src/jlapi.c:1059
main at /cache/build/tester-amdci4-12/julialang/julia-release-1-dot-11/cli/loader_exe.c:58
unknown function (ip: 0x759918229d8f)
__libc_start_main at /lib/x86_64-linux-gnu/libc.so.6 (unknown line)
unknown function (ip: 0x4010b8)
Allocations: 435196342 (Pool: 435185964; Big: 10378); GC: 172
[1]    1530713 illegal hardware instruction (core dumped)  julia --project=@. --threads=20

Enzyme works on this example:

val, grad = DI.value_and_gradient(
    loss_acl, 
    ADTypes.AutoEnzyme(;
            mode=Enzyme.set_runtime_activity(Enzyme.Reverse),
            function_annotation=Enzyme.Const,
        ),
    psacl, DI.Cache(stacl), DI.Constant(x)
)

with output

(5.453049189044009, [0.00547994735278205, -0.0026771692066405567, -0.005708182026795597, -0.7924439468724516, -0.7544730762364316, -0.16686686701374107, 0.4478351144536696, -0.006762348018489824, 0.004773511138093715, -0.0005490333340741215  …  0.0020767492481565375, 0.00012238617695840655, 0.0028390198598604306, 0.0001673080114345703, -0.28502936510323545, -0.016797239410021175, -0.5721215941140413, -0.0337160467115233, -4.572356477535249, -0.2694563291515709])

Any thoughts on this @yebai @willtebbutt?

@zuhengxu (Member, Author) commented Jul 23, 2025

Ah, it looks like the issue only occurs when just part of the struct's fields are annotated with @functor; the following example, where the struct contains only a single trainable field, works:

struct Holder
    t::Flux.Chain
end
@functor Holder 

psh, sth = Optimisers.destructure(Holder(t_net))

function loss2(ps, st, x)
    holder = st(ps)
    t_net = holder.t
    y = x .+ t_net(x)
    return sum(abs2, y)
end
loss2(psh, sth, x)  # returns 7.408352005690478

val, grad = DI.value_and_gradient(
    loss2, 
    ADTypes.AutoMooncake(; config = Mooncake.Config()), 
    psh, DI.Cache(sth), DI.Constant(x)
)

with output:

(7.408352005690478, [0.016964275123278565, -1.6770328865228394, 0.012168583592473805, -1.899494171602285, -0.03156295180782276, 1.1263707232774889, 0.041515954305201684, -1.633822068384897, 0.0291423962210742, -0.06883290364310875  …  -0.5149587512518972, 0.3813047278162709, 0.03035515401339805, 0.004346607224947255, 0.004533179126481482, -0.0033566234747215667, -5.299778474492951, -0.7588844845815451, -0.7914585162356058, 0.5860408690548988])
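Since the failure seems tied to annotating only a subset of fields with @functor, one possible workaround is to make the whole struct a functor and restrict trainability via Optimisers.trainable instead. This is only a sketch mirroring the ACL example above (ACL2 is a hypothetical name), and I haven't verified that it actually sidesteps the Mooncake crash:

```julia
using Flux, Optimisers
using Functors: @functor
using Bijectors: PartitionMask

struct ACL2
    mask::PartitionMask
    t::Flux.Chain
end

# Mark the whole struct as a functor (no field subset), ...
@functor ACL2
# ... and tell Optimisers that only `t` holds trainable parameters,
# so `destructure` skips `mask` entirely.
Optimisers.trainable(a::ACL2) = (; t = a.t)
```

The idea is that `Optimisers.destructure` only collects parameters reachable through `trainable`, so `mask` stays out of the flattened vector without needing a partial `@functor` annotation.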

@yebai (Member) commented Jul 24, 2025

@zuhengxu, can you help bisect which Mooncake version / Julia version introduced this bug?

@zuhengxu (Member, Author) commented Jul 24, 2025

> @zuhengxu, can you help bisect which Mooncake version / Julia version introduced this bug?

Good point! I'll look at this today.

@zuhengxu (Member, Author) commented

> @zuhengxu, can you help bisect which Mooncake version / Julia version introduced this bug?

@yebai I finally nailed down the versions. I believe the bug was introduced in Mooncake v0.4.124: all tests pass under Julia v1.11.6 + Mooncake v0.4.123 but fail under Julia v1.11.6 + Mooncake v0.4.124.
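For reference, the bisection can be reproduced by pinning Mooncake versions in the project environment. A sketch only; `NormalizingFlows` stands in here for the package whose test suite is being run:

```julia
using Pkg

# Pin the last known-good version and run the test suite
Pkg.add(name = "Mooncake", version = "0.4.123")
Pkg.test("NormalizingFlows")   # passes under Julia v1.11.6

# Bump to the first bad version and re-run
Pkg.add(name = "Mooncake", version = "0.4.124")
Pkg.test("NormalizingFlows")   # fails with the illegal-instruction crash
```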

@yebai (Member) commented Jul 28, 2025

It appears that the remaining issues with Mooncake are minor, likely due to the lack of a specific rule.

@sunxd3, can you help if it requires a new rule?

@sunxd3 (Member) commented Jul 28, 2025

I'll look into it 👍

@Red-Portal (Member) left a comment

Hi! I have some minor suggestions

github-actions bot (Contributor) commented Aug 8, 2025

NormalizingFlows.jl documentation for PR #53 is available at:
https://TuringLang.github.io/NormalizingFlows.jl/previews/PR53/

@zuhengxu zuhengxu requested review from Red-Portal, yebai and sunxd3 August 9, 2025 05:43
@zuhengxu (Member, Author) commented

Thank you again @yebai @sunxd3 @Red-Portal for the help and comments throughout this PR! Let me know if it looks good to you and I'll merge it.

@Red-Portal (Member) commented

Sorry for the delay. Reviewing a paper for JMLR has been taking up all my bandwidth. I'll take a deeper look tomorrow.

@Red-Portal (Member) left a comment

I only have minor suggestions. Feel free to take a look and apply them if you agree. Otherwise, looks good to me.

@sunxd3 (Member) commented Aug 18, 2025

Sorry for missing the tag; allow me to take a look later today or tomorrow.

@sunxd3 (Member) left a comment

A couple of tiny things; very happy to do another round of review.

@zuhengxu zuhengxu requested review from sunxd3 and Red-Portal August 20, 2025 07:00
@zuhengxu (Member, Author) commented

Thank you @sunxd3 @Red-Portal for the review! I made the corresponding updates; let me know if the current version looks good to you!

@sunxd3 (Member) left a comment

another couple of tiny things, nothing major beyond these

@sunxd3 (Member) commented Aug 20, 2025

Pretty much good to go from my end, but let's wait for Kyurae to take a look?

@Red-Portal (Member) left a comment

I only have a few minor comments that should be quick to handle!

@zuhengxu (Member, Author) commented

@Red-Portal @sunxd3 Let me know if I can hit the big green button! Thanks for the quick feedback.

@Red-Portal (Member) commented

Alright looks good to me now!

@zuhengxu zuhengxu merged commit 3504009 into main Aug 20, 2025
5 checks passed
@zuhengxu zuhengxu deleted the flow_constructor branch August 20, 2025 21:59