Skip to content

Commit e39b8a8

Browse files
committed
wip debug mooncake on coupling layers
1 parent a2f6fbe commit e39b8a8

File tree

7 files changed

+178
-7
lines changed

7 files changed

+178
-7
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "NormalizingFlows"
22
uuid = "50e4474d-9f12-44b7-af7a-91ab30ff6256"
3-
version = "0.2.1"
3+
version = "0.2.2"
44

55
[deps]
66
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -34,4 +34,4 @@ Functors = "0.5.2"
3434
Optimisers = "0.2.16, 0.3, 0.4"
3535
ProgressMeter = "1.0.0"
3636
StatsBase = "0.33, 0.34"
37-
julia = "1.10"
37+
julia = "1.11"

example/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
33
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
44
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
55
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
6+
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
67
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
78
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
9+
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
810
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
911
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
1012
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"

example/demo_RealNVP.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ sample_per_iter = 16
4848
# callback function to log training progress
4949
cb(iter, opt_stats, re, θ) = (sample_per_iter=sample_per_iter,ad=adtype)
5050
# TODO: now using AutoMooncake the example broke, but AutoZygote works, need to debug
51-
adtype = ADTypes.AutoMooncake(; config = Mooncake.Config())
51+
adtype = ADTypes.AutoMooncake(; config = nothing)
52+
# adtype = ADTypes.AutoZygote()
53+
5254
checkconv(iter, stat, re, θ, st) = stat.gradient_norm < one(T)/1000
5355
flow_trained, stats, _ = train_flow(
5456
rng,

example/test_n.jl

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
using Flux
2+
using Bijectors
3+
using Bijectors: partition, combine, PartitionMask
4+
5+
using Random, Distributions, LinearAlgebra
6+
using Functors
7+
using Optimisers, ADTypes
8+
using Mooncake, Zygote, Enzyme, ADTypes
9+
import NormalizingFlows as NF
10+
11+
import DifferentiationInterface as DI
12+
13+
14+
# Scalar element type and problem dimensions for the composed-MLP smoke test.
pt = Float64
inputdim = 4
outputdim = 3

# A single input vector and a batch of `bs` inputs (column-major: one column per sample).
x = randn(pt, inputdim)

bs = 64
xs = randn(pt, inputdim, bs)  # use `bs` instead of repeating the literal 64

# Compose two fully connected networks: m1 maps inputdim -> outputdim and
# m2 maps outputdim -> inputdim, so mm = m2 ∘ m1 maps inputdim -> inputdim.
m1 = NF.fnn(inputdim, [16, 16], outputdim; output_activation=nothing, paramtype=pt)
m2 = NF.fnn(outputdim, [16, 16], inputdim; output_activation=Flux.tanh, paramtype=pt)
# BUG FIX: the composition operator was missing — `reduce(, (m2, m1))` is a
# syntax error; `reduce(∘, ...)` folds the tuple with function composition.
mm = reduce(∘, (m2, m1))
psm, stm = Optimisers.destructure(mm)
28+
29+
# Dummy scalar objective: rebuild the model from the flat parameter vector `ps`
# via the restructure closure `st`, evaluate it on `x`, and sum the output.
function lsm(ps, st, x)
    net = st(ps)
    out = net(x)
    return sum(out)  # placeholder loss, only used to exercise the AD backends
end
34+
35+
adtype = ADTypes.AutoMooncake(; config = Mooncake.Config())
36+
37+
# Differentiate the composed-MLP loss w.r.t. the flat parameter vector `psm`;
# `stm` is passed as a DI Cache and the batch `xs` as a non-differentiated Constant.
val, grad = DI.value_and_gradient(
38+
lsm, adtype,
39+
psm, DI.Cache(stm), DI.Constant(xs)
40+
)
41+
42+
43+
# AffineCoupling layer with hidden sizes [16, 16], coupling the odd indices 1:2:inputdim.
acl = NF.AffineCoupling( inputdim, [16, 16], 1:2:inputdim, pt)
44+
# Flatten the trainable parameters into `psacl`; `stacl` rebuilds the layer from them.
psacl,stacl = Optimisers.destructure(acl)
45+
46+
# Same dummy objective as `lsm`, used against the AffineCoupling layer:
# reconstruct the layer from flat parameters and sum its forward output.
loss(ps, st, x) = sum(st(ps)(x))
51+
52+
# Enzyme (reverse mode with runtime activity) succeeds on the coupling-layer loss.
val, grad = DI.value_and_gradient(
53+
loss,
54+
ADTypes.AutoEnzyme(;
55+
mode=Enzyme.set_runtime_activity(Enzyme.Reverse),
56+
function_annotation=Enzyme.Const,
57+
),
58+
psacl, DI.Cache(stacl), DI.Constant(x)
59+
)
60+
61+
# NOTE(review): the Mooncake variant below is the failing case this commit is
# debugging — kept commented until the AutoMooncake issue is resolved.
# val, grad = DI.value_and_gradient(
62+
# loss,
63+
# ADTypes.AutoMooncake(; config = Mooncake.Config()),
64+
# psacl, DI.Cache(stacl), DI.Constant(x)
65+
# )
66+
67+
# Manually unrolled AffineCoupling forward pass followed by a scalar reduction.
# Mirrors the layer's transform — y₁ = exp(s(x₂)) ⊙ x₁ + t(x₂) on the coupled
# partition — while bypassing the functor call, to localize the AD failure.
function loss_acl_manual(ps, st, x)
    layer = st(ps)
    x₁, x₂, x₃ = partition(layer.mask, x)
    scaled = exp.(layer.s(x₂)) .* x₁
    shifted = scaled .+ layer.t(x₂)
    return sum(combine(layer.mask, shifted, x₂, x₃))
end
78+
79+
# NOTE(review): as committed, every backend line below was commented out, so
# `psacl` was passed where DifferentiationInterface expects the AD backend,
# raising a MethodError before any gradient work started. Restore the Enzyme
# backend (the configuration that succeeds for `loss` above); the Mooncake
# backend stays commented because it is the failure case under debug.
val, grad = DI.value_and_gradient(
    loss_acl_manual,
    # ADTypes.AutoMooncake(; config = Mooncake.Config()),
    ADTypes.AutoEnzyme(;
        mode=Enzyme.set_runtime_activity(Enzyme.Reverse),
        function_annotation=Enzyme.Const,
    ),
    psacl, DI.Cache(stacl), DI.Constant(x)
)
88+
89+
90+
91+
"""
    mlp3(input_dim, hidden_dims, output_dim; activation=Flux.leakyrelu, paramtype=Float64)

Build a three-layer fully connected network (two hidden layers of width
`hidden_dims`, linear output layer) and convert its parameters to `paramtype`.
"""
function mlp3(
    input_dim::Int,
    hidden_dims::Int,
    output_dim::Int;
    activation=Flux.leakyrelu,
    paramtype::Type{T} = Float64
) where {T<:AbstractFloat}
    layers = (
        Flux.Dense(input_dim => hidden_dims, activation),
        Flux.Dense(hidden_dims => hidden_dims, activation),
        Flux.Dense(hidden_dims => output_dim),
    )
    return Flux._paramtype(paramtype, Chain(layers...))
end
105+
106+
# Additive coupling with an externally supplied mask: the coupled coordinates
# x₁ are shifted by t_net(x₂); x₂ and x₃ pass through unchanged.
# Returns the sum of squares of the recombined output.
function ls_msk(ps, st, x, mask)
    shift_net = st(ps)
    x₁, x₂, x₃ = partition(mask, x)
    shifted = x₁ .+ shift_net(x₂)
    return sum(abs2, combine(mask, shifted, x₂, x₃))
end
114+
115+
# Smaller reproduction: a bare PartitionMask plus a plain 3-layer MLP, no
# NormalizingFlows types involved.
inputdim = 4
116+
mask_idx = 1:2:inputdim
117+
mask = PartitionMask(inputdim, mask_idx)
118+
cdim = length(mask_idx)
119+
120+
x = randn(inputdim)
121+
122+
t_net = mlp3(cdim, 16, cdim; paramtype = Float64)
123+
ps, st = Optimisers.destructure(t_net)
124+
125+
# Forward value recorded for reference.
ls_msk(ps, st, x, mask) # 3.0167880799441793
126+
127+
# Mooncake over the mask-based loss; the mask is passed as a non-differentiated Constant.
val, grad = DI.value_and_gradient(
128+
ls_msk,
129+
ADTypes.AutoMooncake(; config = Mooncake.Config()),
130+
ps, DI.Cache(st), DI.Constant(x), DI.Constant(mask)
131+
)
132+
133+
134+
# Minimal stand-in for AffineCoupling with only a shift network, used to
# reproduce the Mooncake failure with as little structure as possible.
struct ACL
135+
mask::Bijectors.PartitionMask
136+
t::Flux.Chain
137+
end
138+
# Restrict Functors traversal to `t` so destructure flattens only the network
# parameters and leaves the static mask out of the parameter vector.
@functor ACL (t, )
139+
140+
# Wrap the mask and shift network in the minimal ACL struct and flatten it.
acl = ACL(mask, t_net)
141+
psacl, stacl = Optimisers.destructure(acl)
142+
143+
# Additive-coupling loss on the ACL wrapper: rebuild the layer from flat
# parameters, shift the coupled coordinates by t(x₂), recombine, and reduce
# with a sum of squares. Should match `ls_msk` on the same parameters.
function loss_acl(ps, st, x)
    layer = st(ps)
    x₁, x₂, x₃ = partition(layer.mask, x)
    shifted = x₁ .+ layer.t(x₂)
    return sum(abs2, combine(layer.mask, shifted, x₂, x₃))
end
152+
# Reference forward value for the flattened ACL (matches `ls_msk` above).
loss_acl(psacl, stacl, x) # 3.0167880799441793
153+
154+
# Enzyme gradient on the ACL-based loss: the working backend.
val, grad = DI.value_and_gradient(
155+
loss_acl,
156+
ADTypes.AutoEnzyme(;
157+
mode=Enzyme.set_runtime_activity(Enzyme.Reverse),
158+
function_annotation=Enzyme.Const,
159+
),
160+
psacl, DI.Cache(stacl), DI.Constant(x)
161+
)
162+
163+
# Mooncake gradient on the same loss: the failing case this commit is debugging.
val, grad = DI.value_and_gradient(
164+
loss_acl,
165+
ADTypes.AutoMooncake(; config = Mooncake.Config()),
166+
psacl, DI.Cache(stacl), DI.Constant(x)
167+
)

src/flows/realnvp.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ function Bijectors.transform(af::AffineCoupling, x::AbstractVecOrMat)
3838
return combine(af.mask, y₁, x₂, x₃)
3939
end
4040

41-
function (af::AffineCoupling)(x::AbstractArray)
41+
function (af::AffineCoupling)(x::AbstractVecOrMat)
4242
return transform(af, x)
4343
end
4444

@@ -191,4 +191,4 @@ In *NeurIPS*.
191191
"""
192192
realnvp(q0; paramtype::Type{T} = Float64) where {T<:AbstractFloat} = realnvp(
193193
q0, [32, 32], 10; paramtype=paramtype
194-
)
194+
)

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1717
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1818

1919
[compat]
20-
Mooncake = "0.4.101"
20+
Mooncake = "0.4.140"

test/ad.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ end
8484
mode=Enzyme.set_runtime_activity(Enzyme.Reverse),
8585
function_annotation=Enzyme.Const,
8686
),
87-
ADTypes.AutoMooncake(; config=Mooncake.Config()),
87+
# ADTypes.AutoMooncake(; config=nothing),
8888
]
8989
@testset "$T" for T in [Float32, Float64]
9090
μ = 10 * ones(T, 2)

0 commit comments

Comments
 (0)