Skip to content

Commit fa69b01

Browse files
authored
Map parameter functions to JuMP.Parameters (#389)
* Map parameter functions to JuMP params in infvar mapping dicts * Add parameter funcs to infinite var transcription methods + update relevant unit tests * Add docstring to transcribe_parameter_functions! + include in technical manual * Fix typo in transcribe_parameter_functions! docstring * Add unit test to test value func with real value argument
1 parent 2bd0383 commit fa69b01

File tree

7 files changed

+132
-64
lines changed

7 files changed

+132
-64
lines changed

docs/src/manual/transcribe.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ InfiniteOpt.TranscriptionOpt.set_parameter_supports
1111
InfiniteOpt.TranscriptionOpt.transcribe_finite_parameters!
1212
InfiniteOpt.TranscriptionOpt.transcribe_finite_variables!
1313
InfiniteOpt.TranscriptionOpt.transcribe_infinite_variables!
14+
InfiniteOpt.TranscriptionOpt.transcribe_parameter_functions!
1415
InfiniteOpt.TranscriptionOpt.transcribe_derivative_variables!
1516
InfiniteOpt.TranscriptionOpt.transcribe_semi_infinite_variables!
1617
InfiniteOpt.TranscriptionOpt.transcribe_point_variables!

src/TranscriptionOpt/model.jl

Lines changed: 4 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -351,13 +351,13 @@ function transcription_variable(
351351
return var
352352
end
353353

354-
# InfVarIndex
354+
# InfVarIndex & ParameterFunctionIndex
355355
function transcription_variable(
356356
vref::InfiniteOpt.GeneralVariableRef,
357357
::Type{V},
358358
backend::TranscriptionBackend,
359359
label::Type{<:InfiniteOpt.AbstractSupportLabel}
360-
) where {V <: InfVarIndex}
360+
) where {V <: Union{InfVarIndex, InfiniteOpt.ParameterFunctionIndex}}
361361
vars = get(transcription_data(backend).infvar_mappings, vref, nothing)
362362
if isnothing(vars)
363363
error("Variable reference $vref not used in transcription backend.")
@@ -370,32 +370,6 @@ function transcription_variable(
370370
end
371371
end
372372

373-
# ParameterFunctionIndex
374-
function transcription_variable(
375-
fref::InfiniteOpt.GeneralVariableRef,
376-
::Type{InfiniteOpt.ParameterFunctionIndex},
377-
backend::TranscriptionBackend,
378-
label::Type{<:InfiniteOpt.AbstractSupportLabel}
379-
)
380-
# get the parameter group integer indices of the expression and form the support iterator
381-
group_idxs = InfiniteOpt.parameter_group_int_indices(fref)
382-
support_indices = support_index_iterator(backend, group_idxs)
383-
dims = size(support_indices)[group_idxs]
384-
vals = Array{Float64, length(dims)}(undef, dims...)
385-
# iterate over the indices and compute the values
386-
for idx in support_indices
387-
supp = index_to_support(backend, idx)
388-
val_idx = idx.I[group_idxs]
389-
@inbounds vals[val_idx...] = transcription_expression(fref, backend, supp)
390-
end
391-
# return the values
392-
if _ignore_label(backend, label)
393-
return vals
394-
else
395-
return _truncate_by_label(vals, fref, label, group_idxs, backend)
396-
end
397-
end
398-
399373
# Fallback
400374
function transcription_variable(
401375
vref::InfiniteOpt.GeneralVariableRef,
@@ -521,30 +495,19 @@ _supp_error() = error("""
521495
parameters.
522496
""")
523497

524-
# InfiniteIndex
498+
# InfiniteIndex & ParameterFunctionIndex
525499
function lookup_by_support(
526500
vref::InfiniteOpt.GeneralVariableRef,
527501
::Type{V},
528502
backend::TranscriptionBackend,
529503
support::Vector
530-
) where {V <: InfVarIndex}
504+
) where {V <: Union{InfVarIndex, InfiniteOpt.ParameterFunctionIndex}}
531505
if !haskey(transcription_data(backend).infvar_lookup, vref)
532506
error("Variable reference $vref not used in transcription backend.")
533507
end
534508
return get(_supp_error, transcription_data(backend).infvar_lookup[vref], support)
535509
end
536510

537-
# ParameterFunctionIndex
538-
function lookup_by_support(
539-
fref::InfiniteOpt.GeneralVariableRef,
540-
::Type{InfiniteOpt.ParameterFunctionIndex},
541-
backend::TranscriptionBackend,
542-
support::Vector
543-
)
544-
prefs = InfiniteOpt.raw_parameter_refs(fref)
545-
return InfiniteOpt.call_function(fref, Tuple(support, prefs)...)
546-
end
547-
548511
# FiniteIndex
549512
function lookup_by_support(
550513
vref::InfiniteOpt.GeneralVariableRef,

src/TranscriptionOpt/transcribe.jl

Lines changed: 59 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,62 @@ function _make_var_name(base_name, param_nums, tuple_supp, var_idx)
136136
end
137137
end
138138

139+
"""
140+
transcribe_parameter_functions!(
141+
backend::TranscriptionBackend,
142+
model::InfiniteOpt.InfiniteModel
143+
)::Nothing
144+
145+
Create transcription variables (i.e., JuMP Parameters) corresponding to
146+
all supports of each `Parameter Function` stored in `model` and add them to
147+
`backend`. The variable mappings are also stored in
148+
`TranscriptionData.infvar_mappings` in accordance with
149+
`TranscriptionData.infvar_lookup` which enable [`transcription_variable`](@ref)
150+
and [`lookup_by_support`](@ref). Note that the supports will not be generated
151+
until `InfiniteOpt.variable_supports` is invoked via `InfiniteOpt.supports`.
152+
Note that `TranscriptionData.infvar_support_labels` is also populated.
153+
"""
154+
function transcribe_parameter_functions!(
155+
backend::TranscriptionBackend,
156+
model::InfiniteOpt.InfiniteModel,
157+
)
158+
for (idx, object) in model.param_functions
159+
# get the basic parameter function information
160+
pf = object.func
161+
base_name = object.name
162+
func = pf.func
163+
param_nums = pf.parameter_nums
164+
group_idxs = pf.group_int_idxs
165+
prefs = pf.parameter_refs
166+
# prepare for iterating over its supports
167+
supp_indices = support_index_iterator(backend, group_idxs)
168+
dims = size(supp_indices)[group_idxs]
169+
vrefs = Array{JuMP.VariableRef, length(dims)}(undef, dims...)
170+
supp_type = typeof(Tuple(ones(length(prefs)), prefs))
171+
supps = Array{supp_type, length(dims)}(undef, dims...)
172+
lookup_dict = sizehint!(Dict{Vector{Float64}, JuMP.VariableRef}(), length(vrefs))
173+
# Create a parameter for each support
174+
for i in supp_indices
175+
supp = index_to_support(backend, i)[param_nums]
176+
var_idx = i.I[group_idxs]
177+
tuple_supp = Tuple(supp, prefs)
178+
p_name = _make_var_name(base_name, param_nums, tuple_supp, var_idx)
179+
pValue = func(tuple_supp...)
180+
jump_pref = JuMP.@variable(backend.model, base_name = p_name, set = MOI.Parameter(pValue))
181+
vrefs[var_idx...] = jump_pref
182+
lookup_dict[supp] = jump_pref
183+
supps[var_idx...] = tuple_supp
184+
end
185+
# save the transcription information
186+
pfref = InfiniteOpt.GeneralVariableRef(model, idx)
187+
data = transcription_data(backend)
188+
data.infvar_lookup[pfref] = lookup_dict
189+
data.infvar_mappings[pfref] = vrefs
190+
data.infvar_supports[pfref] = supps
191+
end
192+
return
193+
end
194+
139195
"""
140196
transcribe_infinite_variables!(
141197
backend::TranscriptionBackend,
@@ -475,16 +531,8 @@ function transcription_expression(
475531
backend::TranscriptionBackend,
476532
support::Vector{Float64}
477533
)
478-
ivref = InfiniteOpt.infinite_variable_ref(vref)
479-
if InfiniteOpt._index_type(ivref) == InfiniteOpt.ParameterFunctionIndex
480-
prefs = InfiniteOpt.raw_parameter_refs(ivref)
481-
param_nums = InfiniteOpt._parameter_numbers(ivref)
482-
func = InfiniteOpt.raw_function(ivref)
483-
return func(Tuple(support[param_nums], prefs)...)
484-
else
485-
param_nums = InfiniteOpt._parameter_numbers(vref)
486-
return lookup_by_support(vref, index_type, backend, support[param_nums])
487-
end
534+
param_nums = InfiniteOpt._parameter_numbers(vref)
535+
return lookup_by_support(vref, index_type, backend, support[param_nums])
488536
end
489537

490538
# Point variables, finite variables and finite parameters
@@ -979,6 +1027,7 @@ function build_transcription_backend!(
9791027
transcribe_finite_parameters!(backend, model)
9801028
transcribe_finite_variables!(backend, model)
9811029
transcribe_infinite_variables!(backend, model)
1030+
transcribe_parameter_functions!(backend, model)
9821031
transcribe_derivative_variables!(backend, model)
9831032
transcribe_semi_infinite_variables!(backend, model)
9841033
transcribe_point_variables!(backend, model)

test/TranscriptionOpt/model.jl

Lines changed: 55 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,15 @@ end
121121
@variable(tb.model, d)
122122
@variable(tb.model, e)
123123
@variable(tb.model, f)
124+
@variable(tb.model, p1a in Parameter(sin(0)))
125+
@variable(tb.model, p1b in Parameter(sin(0.5)))
126+
@variable(tb.model, p1c in Parameter(sin(1)))
127+
@variable(tb.model, p2a in Parameter(1))
128+
@variable(tb.model, p2b in Parameter(1))
129+
@variable(tb.model, p2c in Parameter(1))
130+
@variable(tb.model, p2d in Parameter(1))
131+
@variable(tb.model, p2e in Parameter(1))
132+
@variable(tb.model, p2f in Parameter(1))
124133
# test _ignore_label
125134
@testset "_ignore_label" begin
126135
@test IOTO._ignore_label(tb, All)
@@ -156,10 +165,20 @@ end
156165
end
157166
# test IOTO.transcription_variable (Parameter Function)
158167
@testset "IOTO.transcription_variable (Parameter Function)" begin
168+
# test error
169+
@test_throws ErrorException IOTO.transcription_variable(f1, tb)
170+
@test_throws ErrorException IOTO.transcription_variable(f2, tb)
159171
# test normal
160-
@test IOTO.transcription_variable(f1, tb) == [sin(0), sin(1)]
161-
@test IOTO.transcription_variable(f1, tb, label = All) == sin.([0, 0.5, 1])
162-
@test IOTO.transcription_variable(f2, tb) == ones(2, 2)
172+
data.infvar_mappings[f1] = [p1a, p1b, p1c]
173+
data.infvar_mappings[f2] = [p2a p2b; p2c p2d; p2e p2f]
174+
@test JuMP.parameter_value.([p1a, p1b, p1c]) == sin.([0., 0.5, 1.])
175+
@test JuMP.parameter_value.([p2a p2b; p2c p2d; p2e p2f]) == ones(3, 2)
176+
@test IOTO.transcription_variable(f1, tb) isa Vector{JuMP.VariableRef}
177+
@test IOTO.transcription_variable(f1, tb) == [p1a, p1c]
178+
@test IOTO.transcription_variable(f1, tb, label = All) == [p1a, p1b, p1c]
179+
@test IOTO.transcription_variable(f2, tb) isa Matrix{JuMP.VariableRef}
180+
@test IOTO.transcription_variable(f2, tb) == [p2a p2b; p2e p2f]
181+
@test IOTO.transcription_variable(f2, tb, label = All) == [p2a p2b; p2c p2d; p2e p2f]
163182
end
164183
# test IOTO.transcription_variable (Fallback)
165184
@testset "IOTO.transcription_variable (Fallback)" begin
@@ -170,7 +189,7 @@ end
170189
@test IOTO.transcription_variable(y) == a
171190
@test IOTO.transcription_variable(x, label = All) == [a b; c d; e f]
172191
@test IOTO.transcription_variable(x0) == a
173-
@test IOTO.transcription_variable(f2) == ones(2, 2)
192+
@test IOTO.transcription_variable(f2) == [p2a p2b; p2e p2f]
174193
end
175194
# test transformation_variable extension
176195
@testset "transformation_variable" begin
@@ -247,8 +266,30 @@ end
247266
end
248267
# test lookup_by_support (infinite parameter functions)
249268
@testset "lookup_by_support (Parameter Function)" begin
250-
@test IOTO.lookup_by_support(f1, tb, [0.]) == 0
251-
@test IOTO.lookup_by_support(f2, tb, [0., 0., 1.]) == 1
269+
lookups = Dict{Vector{Float64}, VariableRef}(
270+
[0.] => p1a,
271+
[0.5] => p1b,
272+
[1.] => p1c
273+
)
274+
data.infvar_lookup[f1] = lookups
275+
lookups = Dict{Vector{Float64}, VariableRef}(
276+
[0., 0., 0.] => p2a,
277+
[0.5, 0., 0.] => p2b,
278+
[1., 0., 0.] => p2c,
279+
[0., 1., 1.] => p2d,
280+
[0.5, 1., 1.] => p2e,
281+
[1., 1., 1.] => p2f
282+
)
283+
data.infvar_lookup[f2] = lookups
284+
# test errors
285+
@test_throws ErrorException IOTO.lookup_by_support(f1, tb, [0.75])
286+
@test_throws ErrorException IOTO.lookup_by_support(f2, tb, [0., 0., 1.])
287+
# test normal
288+
@test IOTO.lookup_by_support(f1, tb, [0.]) == p1a
289+
@test IOTO.lookup_by_support(f1, tb, [0.5]) == p1b
290+
@test IOTO.lookup_by_support(f1, tb, [1.]) == p1c
291+
@test IOTO.lookup_by_support(f2, tb, [0.5, 0., 0.]) == p2b
292+
@test IOTO.lookup_by_support(f2, tb, [0.5, 1., 1.]) == p2e
252293
end
253294
# test lookup_by_support (finite vars)
254295
@testset "lookup_by_support (Finite)" begin
@@ -405,15 +446,19 @@ end
405446
@variable(tb.model, c)
406447
@variable(tb.model, d)
407448
@variable(tb.model, e in Parameter(42))
449+
@variable(tb.model, pf1 in Parameter(1))
450+
@variable(tb.model, pf2 in Parameter(1))
408451
# transcribe the variables and measures
409452
data = IOTO.transcription_data(tb)
410453
data.finvar_mappings[y] = a
411454
data.finvar_mappings[x0] = a
412455
data.finvar_mappings[finpar] = e
456+
data.infvar_mappings[f] = [pf1, pf2]
413457
data.infvar_mappings[x] = reshape([a, b], :, 1)
414458
data.measure_mappings[meas1] = fill(-2 * zero(AffExpr))
415459
data.measure_mappings[meas2] = [a^2 + c^2 - 2a, b^2 + d^2 - 2a]
416460
data.infvar_lookup[x] = Dict([0, 0, 0] => a, [1, 0, 0] => b)
461+
data.infvar_lookup[f] = Dict([0, 0, 0] => pf1, [1, 0, 0] => pf2)
417462
data.measure_lookup[meas1] = Dict(Float64[] => 1)
418463
data.measure_lookup[meas2] = Dict([0] => 1, [1] => 2)
419464
@test IOTO.set_parameter_supports(tb, m) isa Nothing
@@ -425,14 +470,16 @@ end
425470
@testset "IOTO.transcription_expression (Infinite Variable)" begin
426471
@test IOTO.transcription_expression(x, tb, [0., 0., 0.]) == a
427472
@test IOTO.transcription_expression(meas1, tb, [0., 0., 1.]) == -2 * zero(AffExpr)
428-
@test IOTO.transcription_expression(f, tb, [0., 0., 1.]) == 1
473+
@test IOTO.transcription_expression(f, tb, [0., 0., 1.]) == pf2
429474
end
430475
# test transcription expression for semi_infinite variables with 3 args
431476
@testset "IOTO.transcription_expression (Semi-Infinite Variable)" begin
432477
# semi_infinite of parameter function
433478
rv = add_variable(m, build_variable(error, f, Dict(1=>1.)),
434479
add_support = false)
435-
@test IOTO.transcription_expression(rv, tb, [0., 0., 1.]) == 1
480+
data.infvar_mappings[rv] = [pf2]
481+
data.infvar_lookup[rv] = Dict([0, 0] => pf2)
482+
@test IOTO.transcription_expression(rv, tb, [0., 0., 1.]) == pf2
436483
# semi_infinite of infinite variable
437484
rv = add_variable(m, build_variable(error, x, Dict(1=>1.)),
438485
add_support = false)

test/TranscriptionOpt/transcribe.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -532,8 +532,9 @@ end
532532
@test objective_sense(tb.model) == MOI.MIN_SENSE
533533
# test constraints
534534
yt = IOTO.transcription_variable(y)
535+
ft = IOTO.transformation_variable(f)
535536
dt_c1 = IOTO.lookup_by_support(d1, tb, zeros(3))
536-
@test constraint_object(IOTO.transcription_constraint(c1)).func == -zt + xt[1] + dt_c1
537+
@test constraint_object(IOTO.transcription_constraint(c1)).func == -zt + xt[1] + dt_c1 + ft[1]
537538
@test constraint_object(IOTO.transcription_constraint(c2)).func == zt + xt[1]
538539
expected = IOTO.transcription_variable(meas2)[2] - 2 * IOTO.transcription_variable(y0) + xt[2] + IOTO.transcription_variable(fin)
539540
@test constraint_object(IOTO.transcription_constraint(c4)).func == expected

test/backend_mappings.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,14 +59,16 @@ end
5959
build_transformation_backend!(m)
6060
tb = m.backend
6161
tdata = IOTO.transcription_data(tb)
62+
ft = transformation_variable(f, tb)
6263
# Test transformation_variable
6364
@testset "transformation_variable" begin
6465
# test normal usage
6566
@test transformation_variable(x, label = All) == IOTO.transcription_variable(x, label = All)
6667
@test transformation_variable(x0) == IOTO.transcription_variable(x0)
6768
@test transformation_variable(z) == IOTO.transcription_variable(z)
6869
@test transformation_variable(d1, label = InternalLabel) == IOTO.transcription_variable(d1, label = InternalLabel)
69-
@test transformation_variable(f) == [0, sin(1)]
70+
@test ft isa Array{JuMP.VariableRef}
71+
@test JuMP.parameter_value.(ft) == [0, sin(1)]
7072
# test deprecation
7173
@test (@test_deprecated optimizer_model_variable(z)) == transformation_variable(z)
7274
# test fallback
@@ -104,7 +106,7 @@ end
104106
@test transformation_expression(x^2 + z) == [xt[1]^2 + zt, xt[3]^2 + zt]
105107
@test transformation_expression(x^2 + z, label = All) == [xt[1]^2 + zt, xt[2]^2 + zt, xt[3]^2 + zt]
106108
@test transformation_expression(2z - 3) == 2zt - 3
107-
@test transformation_expression(2 * f) == [zero(AffExpr), zero(AffExpr) + sin(1) * 2]
109+
@test transformation_expression(2 * f) == [2 * ft[1], 2* ft[2]]
108110
# test deprecation
109111
@test (@test_deprecated optimizer_model_expression(2z-4)) == 2zt - 4
110112
# test fallback

test/results.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ end
190190
inf2t = transformation_variable(inf2, label = All)
191191
d1t = transformation_variable(d1, label = All)
192192
rvt = transformation_variable(rv, label = All)
193+
at = transformation_variable(a, label = All)
193194
cref = UpperBoundRef(g)
194195
creft = transformation_constraint(cref, label = All)
195196
# setup the optimizer
@@ -213,6 +214,9 @@ end
213214
MOI.set(mockoptimizer, MOI.VariablePrimal(1), JuMP.optimizer_index(d1t[1]), 2.0)
214215
MOI.set(mockoptimizer, MOI.VariablePrimal(1), JuMP.optimizer_index(d1t[2]), 1.0)
215216
MOI.set(mockoptimizer, MOI.VariablePrimal(1), JuMP.optimizer_index(d1t[3]), 2.0)
217+
MOI.set(mockoptimizer, MOI.VariablePrimal(1), JuMP.optimizer_index(at[1]), sin(0.0))
218+
MOI.set(mockoptimizer, MOI.VariablePrimal(1), JuMP.optimizer_index(at[2]), sin(0.5))
219+
MOI.set(mockoptimizer, MOI.VariablePrimal(1), JuMP.optimizer_index(at[3]), sin(1.0))
216220
# test has_values
217221
@testset "JuMP.has_values" begin
218222
@test has_values(m)
@@ -242,8 +246,9 @@ end
242246
@test value(par, label = All) == [0., 0.5, 1.]
243247
@test value(fin) == 42
244248
@test value(fin, label = All) == 42
245-
@test value(a) == [sin(0.), sin(1.)]
246-
@test value(a, label = All) == [sin(0.), sin(0.5), sin(1.)]
249+
@test value(a) == sin.([0., 1.])
250+
@test value(a, label = All) == sin.([0., 0.5, 1.])
251+
@test value(sin(0.0)) == sin(0.0)
247252
end
248253
#test Reduced Cost
249254
@testset "map_reduced_cost" begin

0 commit comments

Comments
 (0)