@@ -3,7 +3,7 @@ module ModeEstimation
3
3
using .. Turing
4
4
using Bijectors
5
5
using Random
6
- using SciMLBase: OptimizationFunction, OptimizationProblem, AbstractADType
6
+ using SciMLBase: OptimizationFunction, OptimizationProblem, AbstractADType, NoAD
7
7
8
8
using DynamicPPL
9
9
using DynamicPPL: Model, AbstractContext, VarInfo, VarName,
@@ -291,24 +291,47 @@ function optim_objective(model::DynamicPPL.Model, estimator::Union{MLE, MAP}; co
291
291
end
292
292
293
293
294
- function optim_function (model:: DynamicPPL.Model , estimator:: Union{MLE, MAP} ; constrained:: Bool = true , autoad:: Union{Nothing, AbstractADType} = nothing )
294
+ function optim_function (
295
+ model:: Model ,
296
+ estimator:: Union{MLE, MAP} ;
297
+ constrained:: Bool = true ,
298
+ autoad:: Union{Nothing, AbstractADType} = NoAD (),
299
+ )
300
+ if autoad === nothing
301
+ Base. depwarn (" the use of `autoad=nothing` is deprecated, please use `autoad=SciMLBase.NoAD()`" , :optim_function )
302
+ end
303
+
295
304
obj, init, t = optim_objective (model, estimator; constrained= constrained)
296
305
297
- l (x,p) = obj (x)
298
- f = isa (autoad, AbstractADType) ? OptimizationFunction (l, autoad) : OptimizationFunction (l; grad = (G,x,p) -> obj (nothing , G, nothing , x), hess = (H,x,p) -> obj (nothing , nothing , H, x))
306
+ l (x, _) = obj (x)
307
+ f = if autoad isa AbstractADType && autoad != = NoAD ()
308
+ OptimizationFunction (l, autoad)
309
+ else
310
+ OptimizationFunction (
311
+ l;
312
+ grad = (G,x,p) -> obj (nothing , G, nothing , x),
313
+ hess = (H,x,p) -> obj (nothing , nothing , H, x),
314
+ )
315
+ end
299
316
300
317
return (func= f, init= init, transform = t)
301
318
end
302
319
303
320
304
- function optim_problem (model:: DynamicPPL.Model , estimator:: Union{MAP, MLE} ; constrained:: Bool = true , init_theta= nothing , autoad:: Union{Nothing, AbstractADType} = nothing , kwargs... )
305
- f = optim_function (model, estimator; constrained= constrained, autoad= autoad)
306
-
307
- init_theta = init_theta === nothing ? f. init () : f. init (init_theta)
321
+ function optim_problem (
322
+ model:: Model ,
323
+ estimator:: Union{MAP, MLE} ;
324
+ constrained:: Bool = true ,
325
+ init_theta= nothing ,
326
+ autoad:: Union{Nothing, AbstractADType} = NoAD (),
327
+ kwargs... ,
328
+ )
329
+ f, init, transform = optim_function (model, estimator; constrained= constrained, autoad= autoad)
308
330
309
- prob = OptimizationProblem (f. func, init_theta, nothing ; kwargs... )
331
+ u0 = init_theta === nothing ? init () : init (init_theta)
332
+ prob = OptimizationProblem (f, u0; kwargs... )
310
333
311
- return (prob= prob , init= f . init, transform = f . transform)
334
+ return (; prob, init, transform)
312
335
end
313
336
314
337
end
0 commit comments