Skip to content

Commit dbe841f

Browse files
authored
Big Refactor (#492)
* Big Refactor * replace some lux functions * fix maybe
1 parent 112e843 commit dbe841f

22 files changed

+248
-215
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
2929
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
3030
ScientificTypesBase = "30f210dd-8aff-4c5f-94ba-8e64358c1161"
3131
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
32+
WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"
3233
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3334

3435
[compat]
@@ -57,5 +58,6 @@ SciMLBase = "2"
5758
SciMLSensitivity = "7"
5859
ScientificTypesBase = "3"
5960
Statistics = "1"
61+
WeightInitializers = "1"
6062
Zygote = "0.7"
6163
julia = "1.10"

benchmark/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
33
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
44
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
55
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
6+
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
67
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
78
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
89
OrdinaryDiffEqDefault = "50262376-6c5a-4cf5-baba-aaf4f84d72d7"
@@ -16,6 +17,7 @@ ADTypes = "1"
1617
BenchmarkTools = "1"
1718
ComponentArrays = "0.15"
1819
DifferentiationInterface = "0.7"
20+
Distributions = "0.25"
1921
Lux = "1"
2022
LuxCore = "1"
2123
OrdinaryDiffEqDefault = "1"

benchmark/benchmarks.jl

Lines changed: 66 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ import ADTypes,
22
BenchmarkTools,
33
ComponentArrays,
44
DifferentiationInterface,
5+
Distributions,
56
Lux,
67
LuxCore,
78
OrdinaryDiffEqDefault,
@@ -11,25 +12,18 @@ import ADTypes,
1112
Zygote,
1213
ContinuousNormalizingFlows
1314

14-
SUITE = BenchmarkTools.BenchmarkGroup()
15-
16-
SUITE["main"] = BenchmarkTools.BenchmarkGroup(["package", "simple"])
17-
18-
SUITE["main"]["no_inplace"] = BenchmarkTools.BenchmarkGroup(["no_inplace"])
19-
SUITE["main"]["inplace"] = BenchmarkTools.BenchmarkGroup(["inplace"])
20-
21-
SUITE["main"]["no_inplace"]["direct"] = BenchmarkTools.BenchmarkGroup(["direct"])
22-
SUITE["main"]["no_inplace"]["AD-1-order"] = BenchmarkTools.BenchmarkGroup(["gradient"])
23-
24-
SUITE["main"]["inplace"]["direct"] = BenchmarkTools.BenchmarkGroup(["direct"])
25-
SUITE["main"]["inplace"]["AD-1-order"] = BenchmarkTools.BenchmarkGroup(["gradient"])
26-
2715
rng = StableRNGs.StableRNG(1)
28-
nvars = 2^3
16+
ndata = 2^10
17+
ndimension = 1
18+
data_dist = Distributions.Beta{Float32}(2.0f0, 4.0f0)
19+
r = rand(rng, data_dist, ndimension, ndata)
20+
r = convert.(Float32, r)
21+
22+
nvars = size(r, 1)
2923
naugs = nvars
3024
n_in = nvars + naugs
31-
n = 2^6
32-
nn = Lux.Chain(Lux.Dense(n_in => n_in, tanh))
25+
26+
nn = Lux.Chain(Lux.Dense(n_in => 3 * n_in, tanh), Lux.Dense(3 * n_in => n_in, tanh))
3327

3428
icnf = ContinuousNormalizingFlows.construct(
3529
ContinuousNormalizingFlows.ICNF,
@@ -49,9 +43,29 @@ icnf = ContinuousNormalizingFlows.construct(
4943
sensealg = SciMLSensitivity.InterpolatingAdjoint(),
5044
),
5145
)
46+
47+
icnf2 = ContinuousNormalizingFlows.construct(
48+
ContinuousNormalizingFlows.ICNF,
49+
nn,
50+
nvars,
51+
naugs;
52+
inplace = true,
53+
compute_mode = ContinuousNormalizingFlows.LuxVecJacMatrixMode(ADTypes.AutoZygote()),
54+
tspan = (0.0f0, 1.0f0),
55+
steer_rate = 1.0f-1,
56+
λ₁ = 1.0f-2,
57+
λ₂ = 1.0f-2,
58+
λ₃ = 1.0f-2,
59+
rng,
60+
sol_kwargs = (;
61+
save_everystep = false,
62+
alg = OrdinaryDiffEqDefault.DefaultODEAlgorithm(),
63+
sensealg = SciMLSensitivity.InterpolatingAdjoint(),
64+
),
65+
)
66+
5267
ps, st = LuxCore.setup(icnf.rng, icnf)
5368
ps = ComponentArrays.ComponentArray(ps)
54-
r = rand(icnf.rng, Float32, nvars, n)
5569

5670
function diff_loss_tn(x::Any)
5771
return ContinuousNormalizingFlows.loss(
@@ -72,49 +86,6 @@ function diff_loss_tt(x::Any)
7286
)
7387
end
7488

75-
diff_loss_tn(ps)
76-
diff_loss_tt(ps)
77-
DifferentiationInterface.gradient(diff_loss_tn, ADTypes.AutoZygote(), ps)
78-
DifferentiationInterface.gradient(diff_loss_tt, ADTypes.AutoZygote(), ps)
79-
GC.gc()
80-
81-
SUITE["main"]["no_inplace"]["direct"]["train"] =
82-
BenchmarkTools.@benchmarkable diff_loss_tn(ps)
83-
SUITE["main"]["no_inplace"]["direct"]["test"] =
84-
BenchmarkTools.@benchmarkable diff_loss_tt(ps)
85-
SUITE["main"]["no_inplace"]["AD-1-order"]["train"] =
86-
BenchmarkTools.@benchmarkable DifferentiationInterface.gradient(
87-
diff_loss_tn,
88-
ADTypes.AutoZygote(),
89-
ps,
90-
)
91-
SUITE["main"]["no_inplace"]["AD-1-order"]["test"] =
92-
BenchmarkTools.@benchmarkable DifferentiationInterface.gradient(
93-
diff_loss_tt,
94-
ADTypes.AutoZygote(),
95-
ps,
96-
)
97-
98-
icnf2 = ContinuousNormalizingFlows.construct(
99-
ContinuousNormalizingFlows.ICNF,
100-
nn,
101-
nvars,
102-
naugs;
103-
inplace = true,
104-
compute_mode = ContinuousNormalizingFlows.LuxVecJacMatrixMode(ADTypes.AutoZygote()),
105-
tspan = (0.0f0, 1.0f0),
106-
steer_rate = 1.0f-1,
107-
λ₁ = 1.0f-2,
108-
λ₂ = 1.0f-2,
109-
λ₃ = 1.0f-2,
110-
rng,
111-
sol_kwargs = (;
112-
save_everystep = false,
113-
alg = OrdinaryDiffEqDefault.DefaultODEAlgorithm(),
114-
sensealg = SciMLSensitivity.InterpolatingAdjoint(),
115-
),
116-
)
117-
11889
function diff_loss_tn2(x::Any)
11990
return ContinuousNormalizingFlows.loss(
12091
icnf2,
@@ -134,12 +105,47 @@ function diff_loss_tt2(x::Any)
134105
)
135106
end
136107

108+
diff_loss_tn(ps)
109+
diff_loss_tt(ps)
110+
DifferentiationInterface.gradient(diff_loss_tn, ADTypes.AutoZygote(), ps)
111+
DifferentiationInterface.gradient(diff_loss_tt, ADTypes.AutoZygote(), ps)
112+
137113
diff_loss_tn2(ps)
138114
diff_loss_tt2(ps)
139115
DifferentiationInterface.gradient(diff_loss_tn2, ADTypes.AutoZygote(), ps)
140116
DifferentiationInterface.gradient(diff_loss_tt2, ADTypes.AutoZygote(), ps)
117+
141118
GC.gc()
142119

120+
SUITE = BenchmarkTools.BenchmarkGroup()
121+
122+
SUITE["main"] = BenchmarkTools.BenchmarkGroup(["package", "simple"])
123+
124+
SUITE["main"]["no_inplace"] = BenchmarkTools.BenchmarkGroup(["no_inplace"])
125+
SUITE["main"]["inplace"] = BenchmarkTools.BenchmarkGroup(["inplace"])
126+
127+
SUITE["main"]["no_inplace"]["direct"] = BenchmarkTools.BenchmarkGroup(["direct"])
128+
SUITE["main"]["no_inplace"]["AD-1-order"] = BenchmarkTools.BenchmarkGroup(["gradient"])
129+
130+
SUITE["main"]["inplace"]["direct"] = BenchmarkTools.BenchmarkGroup(["direct"])
131+
SUITE["main"]["inplace"]["AD-1-order"] = BenchmarkTools.BenchmarkGroup(["gradient"])
132+
133+
SUITE["main"]["no_inplace"]["direct"]["train"] =
134+
BenchmarkTools.@benchmarkable diff_loss_tn(ps)
135+
SUITE["main"]["no_inplace"]["direct"]["test"] =
136+
BenchmarkTools.@benchmarkable diff_loss_tt(ps)
137+
SUITE["main"]["no_inplace"]["AD-1-order"]["train"] =
138+
BenchmarkTools.@benchmarkable DifferentiationInterface.gradient(
139+
diff_loss_tn,
140+
ADTypes.AutoZygote(),
141+
ps,
142+
)
143+
SUITE["main"]["no_inplace"]["AD-1-order"]["test"] =
144+
BenchmarkTools.@benchmarkable DifferentiationInterface.gradient(
145+
diff_loss_tt,
146+
ADTypes.AutoZygote(),
147+
ps,
148+
)
143149
SUITE["main"]["inplace"]["direct"]["train"] =
144150
BenchmarkTools.@benchmarkable diff_loss_tn2(ps)
145151
SUITE["main"]["inplace"]["direct"]["test"] = BenchmarkTools.@benchmarkable diff_loss_tt2(ps)

examples/usage.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ model = ICNFModel(
5454
icnf;
5555
optimizers = (Adam(),),
5656
adtype = AutoZygote(),
57-
batch_size = 512,
57+
batchsize = 512,
5858
sol_kwargs = (; progress = true, epochs = 300), # pass to the solver
5959
)
6060
mach = machine(model, df)

src/ContinuousNormalizingFlows.jl

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import ADTypes,
2525
SciMLSensitivity,
2626
ScientificTypesBase,
2727
Statistics,
28+
WeightInitializers,
2829
Zygote
2930

3031
export construct,
@@ -54,13 +55,10 @@ export construct,
5455
include(joinpath("layers", "cond_layer.jl"))
5556
include(joinpath("layers", "planar_layer.jl"))
5657

57-
include("types.jl")
58-
59-
include("base_icnf.jl")
60-
61-
include("icnf.jl")
62-
63-
include("utils.jl")
58+
include(joinpath("core", "types.jl"))
59+
include(joinpath("core", "base_icnf.jl"))
60+
include(joinpath("core", "icnf.jl"))
61+
include(joinpath("core", "utils.jl"))
6462

6563
include(joinpath("exts", "mlj_ext", "core.jl"))
6664
include(joinpath("exts", "mlj_ext", "core_icnf.jl"))
Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -94,12 +94,12 @@ function n_augment_input(::AbstractICNF)
9494
end
9595

9696
function steer_tspan(
97-
icnf::AbstractICNF{T, <:ComputeMode, INPLACE, COND, AUGMENTED, true},
97+
icnf::AbstractICNF{<:AbstractFloat, <:ComputeMode, INPLACE, COND, AUGMENTED, true},
9898
::TrainMode,
99-
) where {T <: AbstractFloat, INPLACE, COND, AUGMENTED}
99+
) where {INPLACE, COND, AUGMENTED}
100100
t₀, t₁ = icnf.tspan
101101
Δt = abs(t₁ - t₀)
102-
r = convert(T, rand(icnf.rng, icnf.steerdist))
102+
r = oftype(t₁, rand(icnf.rng, icnf.steerdist))
103103
t₁_new = muladd(Δt, r, t₁)
104104
return (t₀, t₁_new)
105105
end
@@ -504,12 +504,12 @@ function loss(
504504
end
505505

506506
function make_ode_func(
507-
icnf::AbstractICNF{T, CM, INPLACE},
507+
icnf::AbstractICNF{T},
508508
mode::Mode,
509509
nn::LuxCore.AbstractLuxLayer,
510510
st::NamedTuple,
511511
ϵ::AbstractVecOrMat{T},
512-
) where {T <: AbstractFloat, CM, INPLACE}
512+
) where {T <: AbstractFloat}
513513
function ode_func(u::Any, p::Any, t::Any)
514514
return augmented_f(u, p, t, icnf, mode, nn, st, ϵ)
515515
end
@@ -521,19 +521,19 @@ function make_ode_func(
521521
return ode_func
522522
end
523523

524-
function (icnf::AbstractICNF{T, CM, INPLACE, false})(
524+
function (icnf::AbstractICNF{<:AbstractFloat, <:ComputeMode, INPLACE, false})(
525525
xs::AbstractVecOrMat,
526526
ps::Any,
527527
st::NamedTuple,
528-
) where {T, CM, INPLACE}
528+
) where {INPLACE}
529529
return first(inference(icnf, TrainMode(), xs, ps, st)), st
530530
end
531531

532-
function (icnf::AbstractICNF{T, CM, INPLACE, true})(
532+
function (icnf::AbstractICNF{<:AbstractFloat, <:ComputeMode, INPLACE, true})(
533533
xs_ys::Tuple,
534534
ps::Any,
535535
st::NamedTuple,
536-
) where {T, CM, INPLACE}
536+
) where {INPLACE}
537537
xs, ys = xs_ys
538538
return first(inference(icnf, TrainMode(), xs, ys, ps, st)), st
539539
end
File renamed without changes.
File renamed without changes.
File renamed without changes.

src/exts/dist_ext/core_cond_icnf.jl

Lines changed: 51 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -15,51 +15,66 @@ function CondICNFDist(
1515
return CondICNFDist(mach.model.m, mode, ys, ps, st)
1616
end
1717

18-
function Distributions._logpdf(d::CondICNFDist, x::AbstractVector{<:Real})
19-
return if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode}
20-
first(inference(d.m, d.mode, x, d.ys, d.ps, d.st))
21-
elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode}
22-
@warn "to compute by matrices, data should be a matrix."
23-
first(Distributions._logpdf(d, hcat(x)))
24-
else
25-
error("Not Implemented")
26-
end
18+
function Distributions._logpdf(
19+
d::CondICNFDist{<:AbstractICNF{<:AbstractFloat, <:VectorMode}},
20+
x::AbstractVector{<:Real},
21+
)
22+
return first(inference(d.m, d.mode, x, d.ys, d.ps, d.st))
23+
end
24+
25+
function Distributions._logpdf(
26+
d::CondICNFDist{<:AbstractICNF{<:AbstractFloat, <:MatrixMode}},
27+
x::AbstractVector{<:Real},
28+
)
29+
@warn "to compute by matrices, data should be a matrix."
30+
return first(Distributions._logpdf(d, hcat(x)))
31+
end
32+
33+
function Distributions._logpdf(
34+
d::CondICNFDist{<:AbstractICNF{<:AbstractFloat, <:VectorMode}},
35+
A::AbstractMatrix{<:Real},
36+
)
37+
@warn "to compute by vectors, data should be a vector."
38+
return Distributions._logpdf.(d, collect(collect.(eachcol(A))))
2739
end
28-
function Distributions._logpdf(d::CondICNFDist, A::AbstractMatrix{<:Real})
29-
return if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode}
30-
@warn "to compute by vectors, data should be a vector."
31-
Distributions._logpdf.(d, collect(collect.(eachcol(A))))
32-
elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode}
33-
first(inference(d.m, d.mode, A, d.ys[:, begin:size(A, 2)], d.ps, d.st))
34-
else
35-
error("Not Implemented")
36-
end
40+
41+
function Distributions._logpdf(
42+
d::CondICNFDist{<:AbstractICNF{<:AbstractFloat, <:MatrixMode}},
43+
A::AbstractMatrix{<:Real},
44+
)
45+
return first(inference(d.m, d.mode, A, d.ys[:, begin:size(A, 2)], d.ps, d.st))
3746
end
47+
3848
function Distributions._rand!(
3949
rng::Random.AbstractRNG,
40-
d::CondICNFDist,
50+
d::CondICNFDist{<:AbstractICNF{<:AbstractFloat, <:VectorMode}},
4151
x::AbstractVector{<:Real},
4252
)
43-
return if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode}
44-
x .= generate(d.m, d.mode, d.ys, d.ps, d.st)
45-
elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode}
46-
@warn "to compute by matrices, data should be a matrix."
47-
x .= Distributions._rand!(rng, d, hcat(x))
48-
else
49-
error("Not Implemented")
50-
end
53+
return x .= generate(d.m, d.mode, d.ys, d.ps, d.st)
5154
end
55+
56+
function Distributions._rand!(
57+
rng::Random.AbstractRNG,
58+
d::CondICNFDist{<:AbstractICNF{<:AbstractFloat, <:MatrixMode}},
59+
x::AbstractVector{<:Real},
60+
)
61+
@warn "to compute by matrices, data should be a matrix."
62+
return x .= Distributions._rand!(rng, d, hcat(x))
63+
end
64+
65+
function Distributions._rand!(
66+
rng::Random.AbstractRNG,
67+
d::CondICNFDist{<:AbstractICNF{<:AbstractFloat, <:VectorMode}},
68+
A::AbstractMatrix{<:Real},
69+
)
70+
@warn "to compute by vectors, data should be a vector."
71+
return A .= hcat(Distributions._rand!.(rng, d, collect(collect.(eachcol(A))))...)
72+
end
73+
5274
function Distributions._rand!(
5375
rng::Random.AbstractRNG,
54-
d::CondICNFDist,
76+
d::CondICNFDist{<:AbstractICNF{<:AbstractFloat, <:MatrixMode}},
5577
A::AbstractMatrix{<:Real},
5678
)
57-
return if d.m isa AbstractICNF{<:AbstractFloat, <:VectorMode}
58-
@warn "to compute by vectors, data should be a vector."
59-
A .= hcat(Distributions._rand!.(rng, d, collect(collect.(eachcol(A))))...)
60-
elseif d.m isa AbstractICNF{<:AbstractFloat, <:MatrixMode}
61-
A .= generate(d.m, d.mode, d.ys[:, begin:size(A, 2)], d.ps, d.st, size(A, 2))
62-
else
63-
error("Not Implemented")
64-
end
79+
return A .= generate(d.m, d.mode, d.ys[:, begin:size(A, 2)], d.ps, d.st, size(A, 2))
6580
end

0 commit comments

Comments
 (0)