1- import NLSolversBase:
2- value, value!, value!!, gradient, gradient!, value_gradient!, value_gradient!!
31# ###### FIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIIX THE MIDDLE OF BOX CASE THAT WAS THERE
42mutable struct BarrierWrapper{TO,TB,Tm,TF,TDF} <: AbstractObjective
53 obj:: TO
@@ -64,24 +62,24 @@ function _barrier_term_gradient(x::T, l, u) where {T}
6462 end
6563 return g
6664end
67- function value_gradient! (bb:: BoxBarrier , g, x)
65+ function NLSolversBase . value_gradient! (bb:: BoxBarrier , g, x)
6866 g .= _barrier_term_gradient .(x, bb. lower, bb. upper)
6967 value (bb, x)
7068end
71- function gradient (bb:: BoxBarrier , g, x)
69+ function NLSolversBase . gradient (bb:: BoxBarrier , g, x)
7270 g = copy (g)
7371 g .= _barrier_term_gradient .(x, bb. lower, bb. upper)
7472end
7573# Wrappers
76- function value!! (bw:: BarrierWrapper , x)
74+ function NLSolversBase . value!! (bw:: BarrierWrapper , x)
7775 bw. Fb = value (bw. b, x)
7876 bw. Ftotal = bw. mu * bw. Fb
7977 if in_box (bw, x)
8078 value!! (bw. obj, x)
8179 bw. Ftotal += value (bw. obj)
8280 end
8381end
84- function value_gradient!! (bw:: BarrierWrapper , x)
82+ function NLSolversBase . value_gradient!! (bw:: BarrierWrapper , x)
8583 bw. Fb = value (bw. b, x)
8684 bw. Ftotal = bw. mu * bw. Fb
8785 bw. DFb .= _barrier_term_gradient .(x, bw. b. lower, bw. b. upper)
@@ -93,7 +91,7 @@ function value_gradient!!(bw::BarrierWrapper, x)
9391 end
9492
9593end
96- function value_gradient! (bb:: BarrierWrapper , x)
94+ function NLSolversBase . value_gradient! (bb:: BarrierWrapper , x)
9795 bb. DFb .= _barrier_term_gradient .(x, bb. b. lower, bb. b. upper)
9896 bb. Fb = value (bb. b, x)
9997 bb. DFtotal .= bb. mu .* bb. DFb
@@ -105,9 +103,9 @@ function value_gradient!(bb::BarrierWrapper, x)
105103 bb. Ftotal += value (bb. obj)
106104 end
107105end
108- value (bb:: BoxBarrier , x) =
106+ NLSolversBase . value (bb:: BoxBarrier , x) =
109107 mapreduce (x -> _barrier_term_value (x... ), + , zip (x, bb. lower, bb. upper))
110- function value! (obj:: BarrierWrapper , x)
108+ function NLSolversBase . value! (obj:: BarrierWrapper , x)
111109 obj. Fb = value (obj. b, x)
112110 obj. Ftotal = obj. mu * obj. Fb
113111 if in_box (obj, x)
@@ -116,20 +114,20 @@ function value!(obj::BarrierWrapper, x)
116114 end
117115 obj. Ftotal
118116end
119- value (obj:: BarrierWrapper ) = obj. Ftotal
120- function value (obj:: BarrierWrapper , x)
117+ NLSolversBase . value (obj:: BarrierWrapper ) = obj. Ftotal
118+ function NLSolversBase . value (obj:: BarrierWrapper , x)
121119 F = obj. mu * value (obj. b, x)
122120 if in_box (obj, x)
123121 F += value (obj. obj, x)
124122 end
125123 F
126124end
127- function gradient! (obj:: BarrierWrapper , x)
125+ function NLSolversBase . gradient! (obj:: BarrierWrapper , x)
128126 gradient! (obj. obj, x)
129127 obj. DFb .= gradient (obj. b, obj. DFb, x) # this should just be inplace?
130128 obj. DFtotal .= gradient (obj. obj) .+ obj. mu * obj. Fb
131129end
132- gradient (obj:: BarrierWrapper ) = obj. DFtotal
130+ NLSolversBase . gradient (obj:: BarrierWrapper ) = obj. DFtotal
133131
134132# this mutates mu but not the gradients
135133# Super unsafe in that it depends on x_df being correct!
@@ -299,7 +297,7 @@ function optimize(
299297 initial_x:: AbstractArray ,
300298 F:: Fminbox = Fminbox (),
301299 options:: Options = Options ();
302- inplace = true ,
300+ inplace:: Bool = true ,
303301)
304302
305303 g! = inplace ? g : (G, x) -> copyto! (G, g (x))
@@ -536,12 +534,13 @@ function optimize(
536534 end
537535 results = optimize (dfbox, x, _optimizer, options, state)
538536 stopped_by_callback = results. stopped_by. callback
539- dfbox. obj. f_calls[1 ] = 0
537+ # TODO : Define an API (e.g. `reset_calls!`?) in NLSolversBase
538+ dfbox. obj. f_calls = 0
540539 if hasfield (typeof (dfbox. obj), :df_calls )
541- dfbox. obj. df_calls[ 1 ] = 0
540+ dfbox. obj. df_calls = 0
542541 end
543542 if hasfield (typeof (dfbox. obj), :h_calls )
544- dfbox. obj. h_calls[ 1 ] = 0
543+ dfbox. obj. h_calls = 0
545544 end
546545 copyto! (x, minimizer (results))
547546 boxdist = Base. minimum (((xi, li, ui),) -> min (xi - li, ui - xi), zip (x, l, u)) # Base.minimum !== minimum
@@ -613,12 +612,13 @@ function optimize(
613612 resultsnew = optimize (dfbox, x, _optimizer, options, state)
614613 stopped_by_callback = resultsnew. stopped_by. callback
615614 append! (results, resultsnew)
616- dfbox. obj. f_calls[1 ] = 0
615+ # TODO : Define an API (e.g. `reset_calls!`?) in NLSolversBase
616+ dfbox. obj. f_calls = 0
617617 if hasfield (typeof (dfbox. obj), :df_calls )
618- dfbox. obj. df_calls[ 1 ] = 0
618+ dfbox. obj. df_calls = 0
619619 end
620620 if hasfield (typeof (dfbox. obj), :h_calls )
621- dfbox. obj. h_calls[ 1 ] = 0
621+ dfbox. obj. h_calls = 0
622622 end
623623 copyto! (x, minimizer (results))
624624 boxdist = Base. minimum (((xi, li, ui),) -> min (xi - li, ui - xi), zip (x, l, u)) # Base.minimum !== minimum
0 commit comments