Skip to content

Commit 65318f3

Browse files
some fixes
Signed-off-by: AdityaPandeyCN <[email protected]>
1 parent c65d083 commit 65318f3

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

lib/OptimizationSciPy/src/OptimizationSciPy.jl

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,7 @@ function SciMLBase.__solve(cache::OptimizationCache{F,RC,LB,UB,LC,UC,S,O,D,P,C})
479479
end
480480
θ_vec = [θ]
481481
x = cache.f(θ_vec, cache.p)
482+
x = isa(x, Tuple) ? x : (x,)
482483
opt_state = Optimization.OptimizationState(u = θ_vec, objective = x[1])
483484
if cache.callback(opt_state, x...)
484485
error("Optimization halted by callback")
@@ -623,6 +624,7 @@ function SciMLBase.__solve(cache::OptimizationCache{F,RC,LB,UB,LC,UC,S,O,D,P,C})
623624
end
624625
θ_vec = [θ]
625626
x = cache.f(θ_vec, cache.p)
627+
x = isa(x, Tuple) ? x : (x,)
626628
opt_state = Optimization.OptimizationState(u = θ_vec, objective = x[1])
627629
if cache.callback(opt_state, x...)
628630
error("Optimization halted by callback")
@@ -1355,8 +1357,11 @@ function _create_loss(cache; vector_output::Bool = false)
13551357
end
13561358
θ_julia = ensure_julia_array(θ, eltype(cache.u0))
13571359
x = cache.f(θ_julia, cache.p)
1358-
x = isa(x, Tuple) ? x[1] : x
1359-
x = isa(x, Number) ? [x] : x
1360+
if isa(x, Tuple)
1361+
x = x
1362+
elseif isa(x, Number)
1363+
x = (x,)
1364+
end
13601365
opt_state = Optimization.OptimizationState(u = θ_julia, objective = sum(abs2, x))
13611366
if cache.callback(opt_state, x...)
13621367
error("Optimization halted by callback")
@@ -1372,6 +1377,11 @@ function _create_loss(cache; vector_output::Bool = false)
13721377
end
13731378
θ_julia = ensure_julia_array(θ, eltype(cache.u0))
13741379
x = cache.f(θ_julia, cache.p)
1380+
if isa(x, Tuple)
1381+
x = x
1382+
elseif isa(x, Number)
1383+
x = (x,)
1384+
end
13751385
opt_state = Optimization.OptimizationState(u = θ_julia, objective = x[1])
13761386
if cache.callback(opt_state, x...)
13771387
error("Optimization halted by callback")

0 commit comments

Comments
 (0)