Skip to content

Commit 9795675

Browse files
authored
back to zygote (#461)
* back to zygote * go by reverse * revert to zygote buffer * more back to zygote
1 parent 0b4a951 commit 9795675

File tree

15 files changed

+59
-119
lines changed

15 files changed

+59
-119
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
1313
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
1414
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
1515
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
16-
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
1716
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
1817
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1918
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
@@ -31,6 +30,7 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
3130
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
3231
ScientificTypesBase = "30f210dd-8aff-4c5f-94ba-8e64358c1161"
3332
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
33+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
3434

3535
[weakdeps]
3636
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
@@ -49,7 +49,6 @@ Dates = "1"
4949
DifferentiationInterface = "0.6"
5050
Distributions = "0.25"
5151
DistributionsAD = "0.6"
52-
Enzyme = "0.13"
5352
FillArrays = "1"
5453
LinearAlgebra = "1"
5554
Lux = "1"
@@ -67,4 +66,5 @@ SciMLBase = "2"
6766
SciMLSensitivity = "7"
6867
ScientificTypesBase = "3"
6968
Statistics = "1"
69+
Zygote = "0.6, 0.7"
7070
julia = "1.10"

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,14 +43,14 @@ n_in = nvars + naugs # with augmentation
4343
n = 1024
4444

4545
# Model
46-
using ContinuousNormalizingFlows, Lux, ADTypes #, Enzyme, CUDA, ComputationalResources
46+
using ContinuousNormalizingFlows, Lux, ADTypes #, Zygote, CUDA, ComputationalResources
4747
nn = Chain(Dense(n_in => 3 * n_in, tanh), Dense(3 * n_in => n_in, tanh))
4848
icnf = construct(
4949
RNODE,
5050
nn,
5151
nvars, # number of variables
5252
naugs; # number of augmented dimensions
53-
# compute_mode = DIVecJacMatrixMode(AutoEnzyme(; mode = Enzyme.set_runtime_activity(Enzyme.Reverse), function_annotation = Enzyme.Const)), # process data in batches and use Enzyme
53+
# compute_mode = DIVecJacMatrixMode(AutoZygote()), # process data in batches and use Zygote
5454
# inplace = true, # use the inplace version of functions
5555
# resource = CUDALibs(), # process data by GPU
5656
tspan = (0.0f0, 13.0f0), # have bigger time span
@@ -74,13 +74,13 @@ r = rand(data_dist, nvars, n)
7474
r = convert.(Float32, r)
7575

7676
# Fit It
77-
using DataFrames, MLJBase #, Enzyme, ADTypes, OptimizationOptimisers
77+
using DataFrames, MLJBase #, Zygote, ADTypes, OptimizationOptimisers
7878
df = DataFrame(transpose(r), :auto)
7979
model = ICNFModel(
8080
icnf;
8181
# optimizers = (Lion(),),
8282
# n_epochs = 300,
83-
# adtype = AutoEnzyme(; mode = Enzyme.set_runtime_activity(Enzyme.Reverse), function_annotation = Enzyme.Const),
83+
# adtype = AutoZygote(),
8484
# use_batch = true,
8585
# batch_size = 32,
8686
sol_kwargs = (; progress = true,), # pass to the solver

benchmark/Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,18 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
33
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
44
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
55
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
6-
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
76
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
87
PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d"
98
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
9+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1010

1111
[compat]
1212
ADTypes = "1"
1313
BenchmarkTools = "1"
1414
ComponentArrays = "0.15"
1515
DifferentiationInterface = "0.6"
16-
Enzyme = "0.13"
1716
Lux = "1"
1817
PkgBenchmark = "0.2"
1918
StableRNGs = "1"
19+
Zygote = "0.6, 0.7"
2020
julia = "1.10"

benchmark/benchmarks.jl

Lines changed: 11 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@ import ADTypes,
22
BenchmarkTools,
33
ComponentArrays,
44
DifferentiationInterface,
5-
Enzyme,
65
Lux,
76
PkgBenchmark,
87
StableRNGs,
8+
Zygote,
99
ContinuousNormalizingFlows
1010

1111
SUITE = BenchmarkTools.BenchmarkGroup()
@@ -33,12 +33,7 @@ icnf = ContinuousNormalizingFlows.construct(
3333
nn,
3434
nvars,
3535
naugs;
36-
compute_mode = ContinuousNormalizingFlows.DIVecJacMatrixMode(
37-
ADTypes.AutoEnzyme(;
38-
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
39-
function_annotation = Enzyme.Const,
40-
),
41-
),
36+
compute_mode = ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()),
4237
tspan = (0.0f0, 13.0f0),
4338
steer_rate = 1.0f-1,
4439
λ₃ = 1.0f-2,
@@ -57,22 +52,8 @@ end
5752

5853
diff_loss_tn(ps)
5954
diff_loss_tt(ps)
60-
DifferentiationInterface.gradient(
61-
diff_loss_tn,
62-
ADTypes.AutoEnzyme(;
63-
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
64-
function_annotation = Enzyme.Const,
65-
),
66-
ps,
67-
)
68-
DifferentiationInterface.gradient(
69-
diff_loss_tt,
70-
ADTypes.AutoEnzyme(;
71-
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
72-
function_annotation = Enzyme.Const,
73-
),
74-
ps,
75-
)
55+
DifferentiationInterface.gradient(diff_loss_tn, ADTypes.AutoZygote(), ps)
56+
DifferentiationInterface.gradient(diff_loss_tt, ADTypes.AutoZygote(), ps)
7657
GC.gc()
7758

7859
SUITE["main"]["no_inplace"]["direct"]["train"] =
@@ -82,19 +63,13 @@ SUITE["main"]["no_inplace"]["direct"]["test"] =
8263
SUITE["main"]["no_inplace"]["AD-1-order"]["train"] =
8364
BenchmarkTools.@benchmarkable DifferentiationInterface.gradient(
8465
diff_loss_tn,
85-
ADTypes.AutoEnzyme(;
86-
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
87-
function_annotation = Enzyme.Const,
88-
),
66+
ADTypes.AutoZygote(),
8967
ps,
9068
)
9169
SUITE["main"]["no_inplace"]["AD-1-order"]["test"] =
9270
BenchmarkTools.@benchmarkable DifferentiationInterface.gradient(
9371
diff_loss_tt,
94-
ADTypes.AutoEnzyme(;
95-
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
96-
function_annotation = Enzyme.Const,
97-
),
72+
ADTypes.AutoZygote(),
9873
ps,
9974
)
10075

@@ -104,12 +79,7 @@ icnf2 = ContinuousNormalizingFlows.construct(
10479
nvars,
10580
naugs;
10681
inplace = true,
107-
compute_mode = ContinuousNormalizingFlows.DIVecJacMatrixMode(
108-
ADTypes.AutoEnzyme(;
109-
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
110-
function_annotation = Enzyme.Const,
111-
),
112-
),
82+
compute_mode = ContinuousNormalizingFlows.DIVecJacMatrixMode(ADTypes.AutoZygote()),
11383
tspan = (0.0f0, 13.0f0),
11484
steer_rate = 1.0f-1,
11585
λ₃ = 1.0f-2,
@@ -125,22 +95,8 @@ end
12595

12696
diff_loss_tn2(ps)
12797
diff_loss_tt2(ps)
128-
DifferentiationInterface.gradient(
129-
diff_loss_tn2,
130-
ADTypes.AutoEnzyme(;
131-
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
132-
function_annotation = Enzyme.Const,
133-
),
134-
ps,
135-
)
136-
DifferentiationInterface.gradient(
137-
diff_loss_tt2,
138-
ADTypes.AutoEnzyme(;
139-
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
140-
function_annotation = Enzyme.Const,
141-
),
142-
ps,
143-
)
98+
DifferentiationInterface.gradient(diff_loss_tn2, ADTypes.AutoZygote(), ps)
99+
DifferentiationInterface.gradient(diff_loss_tt2, ADTypes.AutoZygote(), ps)
144100
GC.gc()
145101

146102
SUITE["main"]["inplace"]["direct"]["train"] =
@@ -149,18 +105,12 @@ SUITE["main"]["inplace"]["direct"]["test"] = BenchmarkTools.@benchmarkable diff_
149105
SUITE["main"]["inplace"]["AD-1-order"]["train"] =
150106
BenchmarkTools.@benchmarkable DifferentiationInterface.gradient(
151107
diff_loss_tn2,
152-
ADTypes.AutoEnzyme(;
153-
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
154-
function_annotation = Enzyme.Const,
155-
),
108+
ADTypes.AutoZygote(),
156109
ps,
157110
)
158111
SUITE["main"]["inplace"]["AD-1-order"]["test"] =
159112
BenchmarkTools.@benchmarkable DifferentiationInterface.gradient(
160113
diff_loss_tt2,
161-
ADTypes.AutoEnzyme(;
162-
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
163-
function_annotation = Enzyme.Const,
164-
),
114+
ADTypes.AutoZygote(),
165115
ps,
166116
)

src/ContinuousNormalizingFlows.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ import ADTypes,
1010
DifferentiationInterface,
1111
Distributions,
1212
DistributionsAD,
13-
Enzyme,
1413
FillArrays,
1514
LinearAlgebra,
1615
Lux,
@@ -27,7 +26,8 @@ import ADTypes,
2726
ScientificTypesBase,
2827
SciMLBase,
2928
SciMLSensitivity,
30-
Statistics
29+
Statistics,
30+
Zygote
3131

3232
export construct,
3333
inference,

src/base_icnf.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,7 @@ function construct(
44
nvars::Int,
55
naugmented::Int = 0;
66
data_type::Type{<:AbstractFloat} = Float32,
7-
compute_mode::ComputeMode = DIVecJacMatrixMode(
8-
ADTypes.AutoEnzyme(;
9-
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
10-
function_annotation = Enzyme.Const,
11-
),
12-
),
7+
compute_mode::ComputeMode = DIVecJacMatrixMode(ADTypes.AutoZygote()),
138
inplace::Bool = false,
149
cond::Bool = aicnf <: Union{CondRNODE, CondFFJORD, CondPlanar},
1510
resource::ComputationalResources.AbstractResource = ComputationalResources.CPU1(),

src/exts/mlj_ext/core_cond_icnf.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,7 @@ function CondICNFModel(
1616
loss::Function = loss;
1717
optimizers::Tuple = (Optimisers.Lion(),),
1818
n_epochs::Int = 300,
19-
adtype::ADTypes.AbstractADType = ADTypes.AutoEnzyme(;
20-
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
21-
function_annotation = Enzyme.Const,
22-
),
19+
adtype::ADTypes.AbstractADType = ADTypes.AutoZygote(),
2320
use_batch::Bool = true,
2421
batch_size::Int = 32,
2522
sol_kwargs::NamedTuple = (;),

src/exts/mlj_ext/core_icnf.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,7 @@ function ICNFModel(
1616
loss::Function = loss;
1717
optimizers::Tuple = (Optimisers.Lion(),),
1818
n_epochs::Int = 300,
19-
adtype::ADTypes.AbstractADType = ADTypes.AutoEnzyme(;
20-
mode = Enzyme.set_runtime_activity(Enzyme.Reverse),
21-
function_annotation = Enzyme.Const,
22-
),
19+
adtype::ADTypes.AbstractADType = ADTypes.AutoZygote(),
2320
use_batch::Bool = true,
2421
batch_size::Int = 32,
2522
sol_kwargs::NamedTuple = (;),

src/utils.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,14 @@
66
y = f(xs)
77
z = similar(xs)
88
ChainRulesCore.@ignore_derivatives fill!(z, zero(T))
9-
res = similar(xs, size(xs, 1), size(xs, 1), size(xs, 2))
9+
res = Zygote.Buffer(xs, size(xs, 1), size(xs, 1), size(xs, 2))
1010
for i in axes(xs, 1)
1111
ChainRulesCore.@ignore_derivatives z[i, :] .= one(T)
1212
res[i, :, :] =
1313
only(DifferentiationInterface.pullback(f, icnf.compute_mode.adback, xs, (z,)))
1414
ChainRulesCore.@ignore_derivatives z[i, :] .= zero(T)
1515
end
16-
y, eachslice(res; dims = 3)
16+
y, eachslice(copy(res); dims = 3)
1717
end
1818

1919
@inline function jacobian_batched(
@@ -24,15 +24,15 @@ end
2424
y = f(xs)
2525
z = similar(xs)
2626
ChainRulesCore.@ignore_derivatives fill!(z, zero(T))
27-
res = similar(xs, size(xs, 1), size(xs, 1), size(xs, 2))
27+
res = Zygote.Buffer(xs, size(xs, 1), size(xs, 1), size(xs, 2))
2828
for i in axes(xs, 1)
2929
ChainRulesCore.@ignore_derivatives z[i, :] .= one(T)
3030
res[:, i, :] = only(
3131
DifferentiationInterface.pushforward(f, icnf.compute_mode.adback, xs, (z,)),
3232
)
3333
ChainRulesCore.@ignore_derivatives z[i, :] .= zero(T)
3434
end
35-
y, eachslice(res; dims = 3)
35+
y, eachslice(copy(res); dims = 3)
3636
end
3737

3838
@inline function jacobian_batched(

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
1616
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
1717
TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"
1818
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
19+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1920

2021
[compat]
2122
ADTypes = "1"
@@ -33,4 +34,5 @@ MLJBase = "1"
3334
SciMLBase = "2"
3435
StableRNGs = "1"
3536
TerminalLoggers = "0.1"
37+
Zygote = "0.6, 0.7"
3638
julia = "1.10"

0 commit comments

Comments
 (0)