Skip to content

Commit 3b3f82b

Browse files
committed
Create expression even when there is no JuMP variables
1 parent b9c96ff commit 3b3f82b

File tree

6 files changed

+89
-59
lines changed

6 files changed

+89
-59
lines changed

src/JuMP_wrapper.jl

Lines changed: 18 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -10,25 +10,18 @@ include("operators.jl")
1010

1111
"""
1212
struct IteratorInExpr
13-
iterator::Iterator
13+
iterators::Iterators
1414
index::IteratorIndex
1515
end
1616
17-
Iterator `iterator` with values at index `index`.
17+
Iterator `iterators[index.value]`.
1818
"""
1919
struct IteratorInExpr
2020
iterators::Iterators
2121
index::IteratorIndex
2222
end
2323

24-
function Base.show(io::IO, i::IteratorInExpr)
25-
print(
26-
io,
27-
values_at(i.iterators[i.index.iterator_index], i.index.value_index),
28-
)
29-
print(io, "[i]")
30-
return
31-
end
24+
Base.copy(it::IteratorInExpr) = it
3225

3326
JuMP._is_real(::Union{IteratorInExpr,IteratorIndex}) = true
3427
JuMP.moi_function(i::Union{IteratorInExpr,IteratorIndex}) = i
@@ -55,20 +48,27 @@ function JuMP.jump_function(model, f::FunctionGenerator{F}) where {F}
5548
)
5649
end
5750

58-
_size(expr::ExprGenerator) = getfield.(expr.expr.iterators, :length)
51+
_size(expr::ExprGenerator) = length.(getfield.(expr.expr.iterators, :values))
5952

6053
index_iterators(func, _) = func
6154

6255
function index_iterators(func::IteratorInExpr, index)
6356
idx = func.index
64-
return values_at(iterators[idx.iterator_index], idx.value_index)[index[idx.iterator_index]]
57+
return func.iterators[idx.value].values[index[idx.value]]
6558
end
6659

6760
function index_iterators(func::JuMP.GenericNonlinearExpr, index)
68-
return GenericNonlinearExpr(
69-
func.head,
70-
map(Base.Fix2(index_iterators, index), func.args),
71-
)
61+
args = map(Base.Fix2(index_iterators, index), func.args)
62+
if any(JuMP._has_variable_ref_type, args)
63+
return JuMP.GenericNonlinearExpr(func.head, args)
64+
else
65+
registry = MOI.Nonlinear.OperatorRegistry()
66+
if length(func.args) == 1
67+
MOI.Nonlinear.eval_univariate_function(registry, func.head, args[])
68+
else
69+
MOI.Nonlinear.eval_multivariate_function(registry, func.head, args)
70+
end
71+
end
7272
end
7373

7474
function Base.getindex(expr::ExprGenerator, i::Integer)
@@ -104,10 +104,11 @@ function JuMP.build_constraint(
104104
end
105105

106106
struct IteratedConstraint{
107+
E,
107108
V<:JuMP.GenericVariableRef,
108109
S<:MOI.AbstractVectorSet,
109110
} <: JuMP.AbstractConstraint
110-
func::ExprGenerator{V}
111+
func::ExprGenerator{E,V}
111112
set::S
112113
end
113114

src/MOI_wrapper.jl

Lines changed: 6 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -6,37 +6,23 @@
66
import MathOptInterface as MOI
77

88
"""
9-
struct Iterator
10-
length::Int
11-
values::Vector{Float64}
9+
struct Iterator{T}
10+
values::Vector{T}
1211
end
1312
"""
14-
struct Iterator
15-
length::Int
16-
values::Vector{Float64}
17-
function Iterator(length::Int)
18-
return new(length, Float64[])
19-
end
20-
end
21-
22-
function num_values(it::Iterator)
23-
return div(length(it.values), it.length)
24-
end
25-
26-
function values_at(it::Iterator, i)
27-
return it.values[(1+it.length*(i-1)):(it.length*i)]
13+
struct Iterator{T}
14+
values::Vector{T}
2815
end
2916

3017
struct IteratorIndex
31-
iterator_index::Int
32-
value_index::Int
18+
value::Int
3319
end
3420

3521
Base.copy(i::IteratorIndex) = i
3622

3723
struct FunctionGenerator{F} <: MOI.AbstractVectorFunction
3824
func::MOI.ScalarNonlinearFunction
39-
iterators::Vector{Iterator}
25+
iterators::Vector{Iterator} # Slight type instability, we don't have `Iterator{T}`
4026
end
4127

4228
function Base.copy(f::FunctionGenerator{F}) where {F}

src/operators.jl

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ const Iterators = Vector{Iterator}
1010
expr::JuMP.GenericNonlinearExpr{V}
1111
iterators::Vector{Iterator}
1212
end
13+
14+
Represent a `JuMP.GenericNonlinearExpr` containing iterators.
15+
Thanks to this custom type, we can define a custom `JuMP.build_constraint` method
16+
to generate constraint of different types.
1317
"""
1418
struct ExprTemplate{E,V<:JuMP.AbstractVariableRef} <: JuMP.AbstractJuMPScalar
1519
expr::JuMP.GenericNonlinearExpr{V}
@@ -25,6 +29,8 @@ end
2529

2630
JuMP.variable_ref_type(::Type{ExprTemplate{E,V}}) where {E,V} = V
2731

32+
JuMP.check_belongs_to_model(f::ExprTemplate, model) = JuMP.check_belongs_to_model(f.expr, model)
33+
2834
"""
2935
struct IteratorValues{I}
3036
iterator::Iterator
@@ -59,7 +65,7 @@ iterator([5, -5])
5965
"""
6066
struct IteratorValues{I}
6167
iterators::Iterators
62-
index::Int
68+
index::IteratorIndex
6369
values::I
6470
end
6571

@@ -71,50 +77,66 @@ function Base.show(io::IO, i::IteratorValues)
7177
end
7278

7379
function iterators(axes)
74-
iterators = [Iterator(length(axe)) for axe in axes]
75-
return IteratorValues.(Ref(iterators), eachindex(axes), axes)
80+
iterators = Iterator[Iterator(axe) for axe in axes]
81+
return IteratorValues.(Ref(iterators), IteratorIndex.(eachindex(axes)), axes)
7682
end
7783

7884
iterator(axe) = iterators([axe])[]
7985

8086
function Base.getindex(d::Dict, i::IteratorValues)
81-
return IteratorValues(i.iterators, i.index, [d[val] for val in i.values])
87+
new_values = [d[val] for val in i.values]
88+
i.iterators[i.index.value] = Iterator(new_values)
89+
return IteratorValues(i.iterators, i.index, new_values)
8290
end
8391

8492
# The following is intentionally kept close to JuMP/src/nlp_expr.jl
8593
const _ScalarWithIterator = Union{ExprTemplate,IteratorValues}
8694

95+
function _univariate(f, op, x)
96+
V = something(
97+
_variable_ref_type(x),
98+
JuMP.VariableRef, # FIXME needed if `x` is an iterator
99+
)
100+
nl = JuMP.GenericNonlinearExpr{V}(op, _expr(x))
101+
E = JuMP._MA.promote_operation(f, _type(x))
102+
return ExprTemplate{E}(nl, _iterators(x))
103+
end
104+
87105
# Univariate operators
88106
for f in MOI.Nonlinear.DEFAULT_UNIVARIATE_OPERATORS
107+
op = Meta.quot(f)
89108
if isdefined(Base, f)
90109
@eval function Base.$(f)(x::IteratorValues)
91-
return IteratorValues(x.iterators, x.index, $(f).(x.values))
110+
return _univariate($f, $op, x)
92111
end
93112
@eval function Base.$(f)(x::ExprTemplate)
94-
return ExprTemplate($f(x.expr), x.iterators)
113+
return _univariate($f, $op, x)
95114
end
96115
end
97116
end
98117

99118
function prepare(it::IteratorValues)
100-
append!(it.iterators[it.index].values, it.values)
101-
index = IteratorIndex(it.index, num_values(it.iterators[it.index]))
102-
return IteratorInExpr(it.iterators, index)
119+
@assert it.values == it.iterators[it.index.value].values
120+
return IteratorInExpr(it.iterators, it.index)
103121
end
104122

105123
_expr(f::JuMP.AbstractJuMPScalar) = f
106124
_expr(it::IteratorValues) = prepare(it)
107125
_expr(f::ExprTemplate) = f.expr
126+
_expr(f::Number) = f
108127

109128
_variable_ref_type(f::JuMP.AbstractJuMPScalar) = JuMP.variable_ref_type(f)
110129
_variable_ref_type(::IteratorValues) = nothing
130+
_variable_ref_type(::Number) = nothing
111131

112132
_iterators(::JuMP.AbstractJuMPScalar) = nothing
113133
_iterators(it::_ScalarWithIterator) = it.iterators
134+
_iterators(::Number) = nothing
114135

115136
_type(it::_ScalarWithIterator) = eltype(it.values)
116137
_type(::ExprTemplate{E}) where {E} = E
117138
_type(f::JuMP.AbstractJuMPScalar) = typeof(f)
139+
_type(f::Number) = typeof(f)
118140

119141
_check_equal(it::Iterators, ::Nothing) = it
120142
_check_equal(::Nothing, it::Iterators) = it
@@ -139,10 +161,16 @@ for f in [:+, :-, :*, :^, :/, :atan, :min, :max]
139161
op = Meta.quot(f)
140162
@eval begin
141163
function Base.$(f)(x::IteratorValues, y::Number)
142-
return IteratorValues(x.iterators, x.index, $(f).(x.values, y))
164+
return _multivariate($f, $op, x, y)
143165
end
144166
function Base.$(f)(x::Number, y::IteratorValues)
145-
return IteratorValues(y.iterators, y.index, $(f).(x, y.values))
167+
return _multivariate($f, $op, x, y)
168+
end
169+
function Base.$(f)(x::ExprTemplate, y::Number)
170+
return _multivariate($f, $op, x, y)
171+
end
172+
function Base.$(f)(x::Number, y::ExprTemplate)
173+
return _multivariate($f, $op, x, y)
146174
end
147175
function Base.$(f)(x::_ScalarWithIterator, y::JuMP.AbstractJuMPScalar)
148176
return _multivariate($f, $op, x, y)

src/print.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
function Base.show(io::IO, i::IteratorInExpr)
2+
print(io, i.iterators[i.index.value].values)
3+
print(io, "[i]")
4+
return
5+
end
6+
17
function Base.show(io::IO, f::Union{ExprGenerator,ExprTemplate})
28
return print(io, JuMP.function_string(MIME("text/plain"), f))
39
end

test/JuMP_wrapper.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,15 @@ function test_container()
4040
@test con isa IteratedConstraint
4141
con_it_expr = jump_function(con)
4242
@test con_it_expr isa ExprGenerator
43-
con_expr = con_it_expr.expr
43+
con_expr = con_it_expr.expr.expr
4444
@test sprint(show, con_ref) ==
45-
"ParametrizedArray(((x + (IteratorIndex(1, 1))) - (IteratorIndex(2, 2))) - 0.0, Iterator(2, [-1.0, 1.0, 3.141592653589793, 0.0]), Iterator(2, [-1.0, 1.0, 3.141592653589793, 0.0]) ∈ MathOptInterface.Nonnegatives(4), (iterator([:a, :b]),))"
45+
"GenOpt.ParametrizedArray(((x + (Real[π, 0.0][i])) - (Real[π, 0.0][i])) - 0.0, GenOpt.Iterator{Real}(Real[π, 0.0]) ∈ MathOptInterface.Nonnegatives(2), GenOpt.IteratorValues{Vector{Symbol}}[iterator([:a, :b])])"
4646
@test sprint(show, MIME"text/latex"(), con_ref) ==
47-
"ParametrizedArray(((x + (IteratorIndex(1, 1))) - (IteratorIndex(2, 2))) - 0.0, Iterator(2, [-1.0, 1.0, 3.141592653589793, 0.0]), Iterator(2, [-1.0, 1.0, 3.141592653589793, 0.0]) ∈ MathOptInterface.Nonnegatives(4), (iterator([:a, :b]),))"
47+
"GenOpt.ParametrizedArray(((x + (Real[π, 0.0][i])) - (Real[π, 0.0][i])) - 0.0, GenOpt.Iterator{Real}(Real[π, 0.0]) ∈ MathOptInterface.Nonnegatives(2), GenOpt.IteratorValues{Vector{Symbol}}[iterator([:a, :b])])"
4848
@test sprint(show, con_expr) ==
49-
"((x + (IteratorIndex(1, 1))) - (IteratorIndex(2, 2))) - 0.0"
49+
"((x + (Real[π, 0.0][i])) - (Real[π, 0.0][i])) - 0.0"
5050
@test sprint(show, MIME"text/latex"(), con_expr) ==
51-
"\$ {\\left({\\left({x} + {\\left(IteratorIndex(1, 1)\\right)}\\right)} - {\\left(IteratorIndex(2, 2)\\right)}\\right)} - {0.0} \$"
51+
"\$ {\\left({\\left({x} + {\\left(Real[π, 0.0][i]\\right)}\\right)} - {\\left(Real[π, 0.0][i]\\right)}\\right)} - {0.0} \$"
5252

5353
i = GenOpt.iterator(keys)
5454
expr = x + d1[i] - d2[i]

test/operators.jl

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,30 +22,39 @@ end
2222

2323
function _test_iterator(it, values)
2424
@test it isa IteratorValues
25-
@test it.iterators[1].length == length(values)
25+
@test it.iterators[1].values == values
2626
@test it.values == values
2727
end
2828

29+
function _test_template(et, values)
30+
@test et isa ExprTemplate
31+
for i in eachindex(values)
32+
@test index_iterators(et.expr, (i,)) == values[i]
33+
end
34+
end
35+
2936
function test_getindex()
3037
d1 = Dict(:a => -1, :b => 1)
3138
d2 = Dict(:a => π, :b => 0.0)
3239

3340
i = GenOpt.iterator([:a, :b])
3441

3542
_test_iterator(d1[i], [-1, 1])
36-
return _test_iterator(d2[i], Real[π, 0.0])
43+
_test_iterator(d2[i], Real[π, 0.0])
44+
return
3745
end
3846

3947
function test_univariate()
4048
i = GenOpt.iterator([2, -3])
41-
_test_iterator(+i, [2, -3])
42-
return _test_iterator(-i, [-2, 3])
49+
_test_template(+i, [2, -3])
50+
_test_template(-i, [-2, 3])
51+
return
4352
end
4453

4554
function test_multivariate()
4655
i, j = GenOpt.iterators(([2, -3], [1, -1]))
47-
_test_iterator(i + 1, [3, -2])
48-
_test_iterator(2 - i, [0, 5])
56+
_test_template(i + 1, [3, -2])
57+
_test_template(2 - i, [0, 5])
4958
ij = i + j
5059
@test ij isa GenOpt.ExprTemplate{Int,JuMP.VariableRef}
5160
model = JuMP.Model()

0 commit comments

Comments
 (0)