@@ -10,6 +10,10 @@ const Iterators = Vector{Iterator}
1010 expr::JuMP.GenericNonlinearExpr{V}
1111 iterators::Vector{Iterator}
1212 end
13+
14+ Represent a `JuMP.GenericNonlinearExpr` containing iterators.
15+ Thanks to this custom type, we can define a custom `JuMP.build_constraint` method
16+ to generate constraint of different types.
1317"""
1418struct ExprTemplate{E,V<: JuMP.AbstractVariableRef } <: JuMP.AbstractJuMPScalar
1519 expr:: JuMP.GenericNonlinearExpr{V}
2529
2630JuMP. variable_ref_type(:: Type{ExprTemplate{E,V}} ) where {E,V} = V
2731
32+ JuMP. check_belongs_to_model(f:: ExprTemplate , model) = JuMP. check_belongs_to_model(f. expr, model)
33+
2834"""
2935 struct IteratorValues{I}
3036 iterator::Iterator
@@ -59,7 +65,7 @@ iterator([5, -5])
5965"""
6066struct IteratorValues{I}
6167 iterators:: Iterators
62- index:: Int
68+ index:: IteratorIndex
6369 values:: I
6470end
6571
@@ -71,50 +77,66 @@ function Base.show(io::IO, i::IteratorValues)
7177end
7278
7379function iterators(axes)
74- iterators = [Iterator(length( axe) ) for axe in axes]
75- return IteratorValues.(Ref(iterators), eachindex(axes), axes)
80+ iterators = Iterator [Iterator(axe) for axe in axes]
81+ return IteratorValues.(Ref(iterators), IteratorIndex.( eachindex(axes) ), axes)
7682end
7783
7884iterator(axe) = iterators([axe])[]
7985
8086function Base. getindex(d:: Dict , i:: IteratorValues )
81- return IteratorValues(i. iterators, i. index, [d[val] for val in i. values])
87+ new_values = [d[val] for val in i. values]
88+ i. iterators[i. index. value] = Iterator(new_values)
89+ return IteratorValues(i. iterators, i. index, new_values)
8290end
8391
8492# The following is intentionally kept close to JuMP/src/nlp_expr.jl
8593const _ScalarWithIterator = Union{ExprTemplate,IteratorValues}
8694
95+ function _univariate(f, op, x)
96+ V = something(
97+ _variable_ref_type(x),
98+ JuMP. VariableRef, # FIXME needed if `x` is an iterator
99+ )
100+ nl = JuMP. GenericNonlinearExpr{V}(op, _expr(x))
101+ E = JuMP. _MA. promote_operation(f, _type(x))
102+ return ExprTemplate{E}(nl, _iterators(x))
103+ end
104+
87105# Univariate operators
88106for f in MOI. Nonlinear. DEFAULT_UNIVARIATE_OPERATORS
107+ op = Meta. quot(f)
89108 if isdefined(Base, f)
90109 @eval function Base.$ (f)(x:: IteratorValues )
91- return IteratorValues(x . iterators, x . index, $ (f) . (x . values) )
110+ return _univariate( $ f, $ op, x )
92111 end
93112 @eval function Base.$ (f)(x:: ExprTemplate )
94- return ExprTemplate ($ f(x . expr), x . iterators )
113+ return _univariate ($ f, $ op, x )
95114 end
96115 end
97116end
98117
99118function prepare(it:: IteratorValues )
100- append!(it. iterators[it. index]. values, it. values)
101- index = IteratorIndex(it. index, num_values(it. iterators[it. index]))
102- return IteratorInExpr(it. iterators, index)
119+ @assert it. values == it. iterators[it. index. value]. values
120+ return IteratorInExpr(it. iterators, it. index)
103121end
104122
105123_expr(f:: JuMP.AbstractJuMPScalar ) = f
106124_expr(it:: IteratorValues ) = prepare(it)
107125_expr(f:: ExprTemplate ) = f. expr
126+ _expr(f:: Number ) = f
108127
109128_variable_ref_type(f:: JuMP.AbstractJuMPScalar ) = JuMP. variable_ref_type(f)
110129_variable_ref_type(:: IteratorValues ) = nothing
130+ _variable_ref_type(:: Number ) = nothing
111131
112132_iterators(:: JuMP.AbstractJuMPScalar ) = nothing
113133_iterators(it:: _ScalarWithIterator ) = it. iterators
134+ _iterators(:: Number ) = nothing
114135
115136_type(it:: _ScalarWithIterator ) = eltype(it. values)
116137_type(:: ExprTemplate{E} ) where {E} = E
117138_type(f:: JuMP.AbstractJuMPScalar ) = typeof(f)
139+ _type(f:: Number ) = typeof(f)
118140
119141_check_equal(it:: Iterators , :: Nothing ) = it
120142_check_equal(:: Nothing , it:: Iterators ) = it
@@ -139,10 +161,16 @@ for f in [:+, :-, :*, :^, :/, :atan, :min, :max]
139161 op = Meta. quot(f)
140162 @eval begin
141163 function Base.$ (f)(x:: IteratorValues , y:: Number )
142- return IteratorValues(x . iterators, x . index, $ (f) . (x . values , y) )
164+ return _multivariate( $ f, $ op, x , y)
143165 end
144166 function Base.$ (f)(x:: Number , y:: IteratorValues )
145- return IteratorValues(y. iterators, y. index, $ (f). (x, y. values))
167+ return _multivariate($ f, $ op, x, y)
168+ end
169+ function Base.$ (f)(x:: ExprTemplate , y:: Number )
170+ return _multivariate($ f, $ op, x, y)
171+ end
172+ function Base.$ (f)(x:: Number , y:: ExprTemplate )
173+ return _multivariate($ f, $ op, x, y)
146174 end
147175 function Base.$ (f)(x:: _ScalarWithIterator , y:: JuMP.AbstractJuMPScalar )
148176 return _multivariate($ f, $ op, x, y)
0 commit comments