Skip to content

Commit 3f1c176

Browse files
fix: use correct parameter object in OptimizationBaseMTKExt
1 parent 87e358f commit 3f1c176

File tree

1 file changed

+32
-29
lines changed

1 file changed

+32
-29
lines changed

lib/OptimizationBase/ext/OptimizationMTKExt.jl

Lines changed: 32 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,14 @@ function OptimizationBase.instantiate_function(
2121
num_cons))))
2222
#sys = ModelingToolkit.structural_simplify(sys)
2323
# don't need to pass `x` or `p` since they're defaults now
24-
f = OptimizationProblem(sys, nothing; grad = g, hess = h,
24+
mtkprob = OptimizationProblem(sys, nothing; grad = g, hess = h,
2525
sparse = true, cons_j = cons_j, cons_h = cons_h,
26-
cons_sparse = true).f
26+
cons_sparse = true)
27+
f = mtkprob.f
2728

28-
grad = (G, θ, args...) -> f.grad(G, θ, p, args...)
29+
grad = (G, θ, args...) -> f.grad(G, θ, mtkprob.p, args...)
2930

30-
hess = (H, θ, args...) -> f.hess(H, θ, p, args...)
31+
hess = (H, θ, args...) -> f.hess(H, θ, mtkprob.p, args...)
3132

3233
hv = function (H, θ, v, args...)
3334
res = (eltype(θ)).(f.hess_prototype)
@@ -36,9 +37,9 @@ function OptimizationBase.instantiate_function(
3637
end
3738

3839
if !isnothing(f.cons)
39-
cons = (res, θ) -> f.cons(res, θ, p)
40-
cons_j = (J, θ) -> f.cons_j(J, θ, p)
41-
cons_h = (res, θ) -> f.cons_h(res, θ, p)
40+
cons = (res, θ) -> f.cons(res, θ, mtkprob.p)
41+
cons_j = (J, θ) -> f.cons_j(J, θ, mtkprob.p)
42+
cons_h = (res, θ) -> f.cons_h(res, θ, mtkprob.p)
4243
else
4344
cons = nothing
4445
cons_j = nothing
@@ -72,24 +73,24 @@ function OptimizationBase.instantiate_function(
7273
num_cons))))
7374
#sys = ModelingToolkit.structural_simplify(sys)
7475
# don't need to pass `x` or `p` since they're defaults now
75-
f = OptimizationProblem(sys, nothing; grad = g, hess = h,
76+
mtkprob = OptimizationProblem(sys, nothing; grad = g, hess = h,
7677
sparse = true, cons_j = cons_j, cons_h = cons_h,
77-
cons_sparse = true).f
78+
cons_sparse = true)
79+
f = mtkprob.f
7880

79-
grad = (G, θ, args...) -> f.grad(G, θ, cache.p, args...)
81+
grad = (G, θ, args...) -> f.grad(G, θ, mtkprob.p, args...)
8082

81-
hess = (H, θ, args...) -> f.hess(H, θ, cache.p, args...)
83+
hess = (H, θ, args...) -> f.hess(H, θ, mtkprob.p, args...)
8284

8385
hv = function (H, θ, v, args...)
8486
res = (eltype(θ)).(f.hess_prototype)
8587
hess(res, θ, args...)
8688
H .= res * v
8789
end
88-
8990
if !isnothing(f.cons)
90-
cons = (res, θ) -> f.cons(res, θ, cache.p)
91-
cons_j = (J, θ) -> f.cons_j(J, θ, cache.p)
92-
cons_h = (res, θ) -> f.cons_h(res, θ, cache.p)
91+
cons = (res, θ) -> f.cons(res, θ, mtkprob.p)
92+
cons_j = (J, θ) -> f.cons_j(J, θ, mtkprob.p)
93+
cons_h = (res, θ) -> f.cons_h(res, θ, mtkprob.p)
9394
else
9495
cons = nothing
9596
cons_j = nothing
@@ -121,13 +122,14 @@ function OptimizationBase.instantiate_function(
121122
num_cons))))
122123
#sys = ModelingToolkit.structural_simplify(sys)
123124
# don't need to pass `x` or `p` since they're defaults now
124-
f = OptimizationProblem(sys, nothing; grad = g, hess = h,
125+
mtkprob = OptimizationProblem(sys, nothing; grad = g, hess = h,
125126
sparse = false, cons_j = cons_j, cons_h = cons_h,
126-
cons_sparse = false).f
127+
cons_sparse = false)
128+
f = mtkprob.f
127129

128-
grad = (G, θ, args...) -> f.grad(G, θ, p, args...)
130+
grad = (G, θ, args...) -> f.grad(G, θ, mtkprob.p, args...)
129131

130-
hess = (H, θ, args...) -> f.hess(H, θ, p, args...)
132+
hess = (H, θ, args...) -> f.hess(H, θ, mtkprob.p, args...)
131133

132134
hv = function (H, θ, v, args...)
133135
res = ArrayInterface.zeromatrix(θ)
@@ -136,9 +138,9 @@ function OptimizationBase.instantiate_function(
136138
end
137139

138140
if !isnothing(f.cons)
139-
cons = (res, θ) -> f.cons(res, θ, p)
140-
cons_j = (J, θ) -> f.cons_j(J, θ, p)
141-
cons_h = (res, θ) -> f.cons_h(res, θ, p)
141+
cons = (res, θ) -> f.cons(res, θ, mtkprob.p)
142+
cons_j = (J, θ) -> f.cons_j(J, θ, mtkprob.p)
143+
cons_h = (res, θ) -> f.cons_h(res, θ, mtkprob.p)
142144
else
143145
cons = nothing
144146
cons_j = nothing
@@ -172,13 +174,14 @@ function OptimizationBase.instantiate_function(
172174
num_cons))))
173175
#sys = ModelingToolkit.structural_simplify(sys)
174176
# don't need to pass `x` or `p` since they're defaults now
175-
f = OptimizationProblem(sys, nothing; grad = g, hess = h,
177+
mtkprob = OptimizationProblem(sys, nothing; grad = g, hess = h,
176178
sparse = false, cons_j = cons_j, cons_h = cons_h,
177-
cons_sparse = false).f
179+
cons_sparse = false)
180+
f = mtkprob.f
178181

179-
grad = (G, θ, args...) -> f.grad(G, θ, cache.p, args...)
182+
grad = (G, θ, args...) -> f.grad(G, θ, mtkprob.p, args...)
180183

181-
hess = (H, θ, args...) -> f.hess(H, θ, cache.p, args...)
184+
hess = (H, θ, args...) -> f.hess(H, θ, mtkprob.p, args...)
182185

183186
hv = function (H, θ, v, args...)
184187
res = ArrayInterface.zeromatrix(θ)
@@ -187,9 +190,9 @@ function OptimizationBase.instantiate_function(
187190
end
188191

189192
if !isnothing(f.cons)
190-
cons = (res, θ) -> f.cons(res, θ, cache.p)
191-
cons_j = (J, θ) -> f.cons_j(J, θ, cache.p)
192-
cons_h = (res, θ) -> f.cons_h(res, θ, cache.p)
193+
cons = (res, θ) -> f.cons(res, θ, mtkprob.p)
194+
cons_j = (J, θ) -> f.cons_j(J, θ, mtkprob.p)
195+
cons_h = (res, θ) -> f.cons_h(res, θ, mtkprob.p)
193196
else
194197
cons = nothing
195198
cons_j = nothing

0 commit comments

Comments
 (0)