Skip to content

Commit 0c788e4

Browse files
committed
[POC] Allow array in nonlinear expressions
1 parent 0b42e5f commit 0c788e4

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

src/nlp_expr.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -575,6 +575,8 @@ function moi_function(f::GenericNonlinearExpr{V}) where {V}
575575
for i in length(f.args):-1:1
576576
if f.args[i] isa GenericNonlinearExpr{V}
577577
push!(stack, (ret, i, f.args[i]))
578+
elseif f.args[i] isa AbstractArray
579+
ret.args[i] = moi_function.(f.args[i])
578580
else
579581
ret.args[i] = moi_function(f.args[i])
580582
end
@@ -586,6 +588,8 @@ function moi_function(f::GenericNonlinearExpr{V}) where {V}
586588
for j in length(arg.args):-1:1
587589
if arg.args[j] isa GenericNonlinearExpr{V}
588590
push!(stack, (child, j, arg.args[j]))
591+
elseif arg.args[j] isa AbstractArray
592+
child.args[j] = moi_function.(arg.args[j])
589593
else
590594
child.args[j] = moi_function(arg.args[j])
591595
end
@@ -611,6 +615,8 @@ function jump_function(model::GenericModel, f::MOI.ScalarNonlinearFunction)
611615
for child in reverse(arg.args)
612616
push!(stack, (new_ret, child))
613617
end
618+
elseif arg isa AbstractArray
619+
push!(parent.args, jump_function.(model, arg))
614620
else
615621
push!(parent.args, jump_function(model, arg))
616622
end

test/test_nlp_expr.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
module TestNLPExpr
77

88
using JuMP
9+
using LinearAlgebra
910
using Test
1011

1112
import LinearAlgebra
@@ -1232,4 +1233,14 @@ function test_extension_euler_to_exp(
12321233
return
12331234
end
12341235

1236+
function test_array()
1237+
model = Model()
1238+
@variable(model, x)
1239+
op_norm = NonlinearOperator(:det, det)
1240+
@objective(model, Min, op_norm([x]))
1241+
f = MOI.get(model, MOI.ObjectiveFunction{MOI.ScalarNonlinearFunction}())
1242+
@test f.head == :norm
1243+
@test f.args == [[index(x)]]
1244+
end
1245+
12351246
end # module

0 commit comments

Comments
 (0)