Skip to content

Commit f0fc1ea

Browse files
authored
Add support for GalacticOptim 3 and fix test errors (#1834)
1 parent 5c8b428 commit f0fc1ea

File tree

5 files changed

+42
-18
lines changed

5 files changed

+42
-18
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "Turing"
22
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
3-
version = "0.21.3"
3+
version = "0.21.4"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/modes/ModeEstimation.jl

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ module ModeEstimation
33
using ..Turing
44
using Bijectors
55
using Random
6-
using SciMLBase: OptimizationFunction, OptimizationProblem, AbstractADType
6+
using SciMLBase: OptimizationFunction, OptimizationProblem, AbstractADType, NoAD
77

88
using DynamicPPL
99
using DynamicPPL: Model, AbstractContext, VarInfo, VarName,
@@ -291,24 +291,47 @@ function optim_objective(model::DynamicPPL.Model, estimator::Union{MLE, MAP}; co
291291
end
292292

293293

294-
function optim_function(model::DynamicPPL.Model, estimator::Union{MLE, MAP}; constrained::Bool=true, autoad::Union{Nothing, AbstractADType}=nothing)
294+
function optim_function(
295+
model::Model,
296+
estimator::Union{MLE, MAP};
297+
constrained::Bool=true,
298+
autoad::Union{Nothing, AbstractADType}=NoAD(),
299+
)
300+
if autoad === nothing
301+
Base.depwarn("the use of `autoad=nothing` is deprecated, please use `autoad=SciMLBase.NoAD()`", :optim_function)
302+
end
303+
295304
obj, init, t = optim_objective(model, estimator; constrained=constrained)
296305

297-
l(x,p) = obj(x)
298-
f = isa(autoad, AbstractADType) ? OptimizationFunction(l, autoad) : OptimizationFunction(l; grad = (G,x,p) -> obj(nothing, G, nothing, x), hess = (H,x,p) -> obj(nothing, nothing, H, x))
306+
l(x, _) = obj(x)
307+
f = if autoad isa AbstractADType && autoad !== NoAD()
308+
OptimizationFunction(l, autoad)
309+
else
310+
OptimizationFunction(
311+
l;
312+
grad = (G,x,p) -> obj(nothing, G, nothing, x),
313+
hess = (H,x,p) -> obj(nothing, nothing, H, x),
314+
)
315+
end
299316

300317
return (func=f, init=init, transform = t)
301318
end
302319

303320

304-
function optim_problem(model::DynamicPPL.Model, estimator::Union{MAP, MLE}; constrained::Bool=true, init_theta=nothing, autoad::Union{Nothing, AbstractADType}=nothing, kwargs...)
305-
f = optim_function(model, estimator; constrained=constrained, autoad=autoad)
306-
307-
init_theta = init_theta === nothing ? f.init() : f.init(init_theta)
321+
function optim_problem(
322+
model::Model,
323+
estimator::Union{MAP, MLE};
324+
constrained::Bool=true,
325+
init_theta=nothing,
326+
autoad::Union{Nothing, AbstractADType}=NoAD(),
327+
kwargs...,
328+
)
329+
f, init, transform = optim_function(model, estimator; constrained=constrained, autoad=autoad)
308330

309-
prob = OptimizationProblem(f.func, init_theta, nothing; kwargs...)
331+
u0 = init_theta === nothing ? init() : init(init_theta)
332+
prob = OptimizationProblem(f, u0; kwargs...)
310333

311-
return (prob=prob, init=f.init, transform = f.transform)
334+
return (; prob, init, transform)
312335
end
313336

314337
end

test/Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
1212
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
1313
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1414
GalacticOptim = "a75be94c-b780-496d-a8a9-0878b188d577"
15+
GalacticOptimJL = "9d3c5eb1-403b-401b-8c0f-c11105342e6b"
1516
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1617
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
1718
Memoization = "6fafb56a-5788-4b4e-91ca-c0cea6611c73"
@@ -43,7 +44,8 @@ DynamicHMC = "2.1.6, 3.0"
4344
DynamicPPL = "0.19.1"
4445
FiniteDifferences = "0.10.8, 0.11, 0.12"
4546
ForwardDiff = "0.10.12"
46-
GalacticOptim = "0.4, 1, 2"
47+
GalacticOptim = "3"
48+
GalacticOptimJL = "0.1"
4749
MCMCChains = "5"
4850
Memoization = "0.1.4"
4951
NamedArrays = "0.9.4"

test/modes/ModeEstimation.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
1-
2-
31
@testset "ModeEstimation.jl" begin
42
@testset "gdemo" begin
53
@testset "MLE" begin
64
Random.seed!(222)
75
true_value = [0.0625, 1.75]
86

97
f1 = optim_function(gdemo_default, MLE();constrained=false)
10-
p1 = GalacticOptim.OptimizationProblem(f1.func, f1.init(true_value), nothing)
8+
p1 = OptimizationProblem(f1.func, f1.init(true_value))
119

1210
p2 = optim_objective(gdemo_default, MLE();constrained=false)
1311

@@ -39,7 +37,7 @@
3937
true_value = [49 / 54, 7 / 6]
4038

4139
f1 = optim_function(gdemo_default, MAP();constrained=false)
42-
p1 = GalacticOptim.OptimizationProblem(f1.func, f1.init(true_value), nothing)
40+
p1 = OptimizationProblem(f1.func, f1.init(true_value))
4341

4442
p2 = optim_objective(gdemo_default, MAP();constrained=false)
4543

@@ -73,7 +71,7 @@
7371
ub = [2.0, 2.0]
7472

7573
f1 = optim_function(gdemo_default, MLE();constrained=true)
76-
p1 = GalacticOptim.OptimizationProblem(f1.func, f1.init(true_value), nothing; lb=lb, ub=ub)
74+
p1 = GalacticOptim.OptimizationProblem(f1.func, f1.init(true_value); lb=lb, ub=ub)
7775

7876
p2 = optim_objective(gdemo_default, MLE();constrained=true)
7977

@@ -101,7 +99,7 @@
10199
ub = [2.0, 2.0]
102100

103101
f1 = optim_function(gdemo_default, MAP();constrained=true)
104-
p1 = GalacticOptim.OptimizationProblem(f1.func, f1.init(true_value), nothing; lb=lb, ub=ub)
102+
p1 = GalacticOptim.OptimizationProblem(f1.func, f1.init(true_value); lb=lb, ub=ub)
105103

106104
p2 = optim_objective(gdemo_default, MAP();constrained=true)
107105

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ using DistributionsAD
77
using FiniteDifferences
88
using ForwardDiff
99
using GalacticOptim
10+
using GalacticOptimJL
1011
using MCMCChains
1112
using Memoization
1213
using NamedArrays

0 commit comments

Comments
 (0)