@@ -2,7 +2,6 @@ using JuMP
2
2
using Flux
3
3
using BilevelJuMP
4
4
5
-
6
5
function solve_bilevel (
7
6
model:: Model ,
8
7
X:: Matrix{<:Real} ,
@@ -27,10 +26,12 @@ function solve_bilevel(
27
26
T = size (Y, 1 )
28
27
29
28
# lower model variables
30
- low_var_map = Dict {JuMP.VariableRef,Vector{BilevelJuMP.BilevelVariableRef}} ()
29
+ low_var_map =
30
+ Dict {JuMP.VariableRef,Vector{BilevelJuMP.BilevelVariableRef}} ()
31
31
for pre_var in all_variables (model. plan)
32
32
low_var_name = string (name (pre_var), " _low" )
33
- low_var_ref = @variable (Lower (bilevel_model), [1 : T], base_name = low_var_name)
33
+ low_var_ref =
34
+ @variable (Lower (bilevel_model), [1 : T], base_name = low_var_name)
34
35
if has_lower_bound (pre_var)
35
36
set_lower_bound .(low_var_ref, lower_bound (pre_var))
36
37
end
@@ -45,7 +46,8 @@ function solve_bilevel(
45
46
for post_var in all_variables (model. assess)
46
47
if ! (post_var in assess_policy_vars (model))
47
48
up_var_name = string (name (post_var), " _up" )
48
- up_var_ref = @variable (Upper (bilevel_model), [1 : T], base_name = up_var_name)
49
+ up_var_ref =
50
+ @variable (Upper (bilevel_model), [1 : T], base_name = up_var_name)
49
51
if has_lower_bound (post_var)
50
52
set_lower_bound .(up_var_ref, lower_bound (post_var))
51
53
end
@@ -65,20 +67,30 @@ function solve_bilevel(
65
67
end
66
68
67
69
# lower model base constraints
68
- for pre_con in
69
- JuMP. all_constraints (model. plan, include_variable_in_set_constraints = false )
70
+ for pre_con in JuMP. all_constraints (
71
+ model. plan,
72
+ include_variable_in_set_constraints = false ,
73
+ )
70
74
pre_con_func = JuMP. constraint_object (pre_con). func
71
75
lhs = [value (x -> low_var_map[x][t], pre_con_func) for t = 1 : T]
72
- @constraint (Lower (bilevel_model), lhs .∈ JuMP. constraint_object (pre_con). set)
76
+ @constraint (
77
+ Lower (bilevel_model),
78
+ lhs .∈ JuMP. constraint_object (pre_con). set
79
+ )
73
80
end
74
81
75
82
# upper model base constraints
76
- for post_con in
77
- JuMP. all_constraints (model. assess, include_variable_in_set_constraints = false )
83
+ for post_con in JuMP. all_constraints (
84
+ model. assess,
85
+ include_variable_in_set_constraints = false ,
86
+ )
78
87
if name (post_con) != " assess_policy_fix"
79
88
post_con_func = JuMP. constraint_object (post_con). func
80
89
lhs = [value (x -> up_var_map[x][t], post_con_func) for t = 1 : T]
81
- @constraint (Upper (bilevel_model), lhs .∈ JuMP. constraint_object (post_con). set)
90
+ @constraint (
91
+ Upper (bilevel_model),
92
+ lhs .∈ JuMP. constraint_object (post_con). set
93
+ )
82
94
end
83
95
end
84
96
@@ -97,7 +109,10 @@ function solve_bilevel(
97
109
# fix upper model observations
98
110
i_obs_var = 1
99
111
for obs_var in assess_forecast_vars (model)
100
- @constraint (Upper (bilevel_model), up_var_map[obs_var] - Y[1 : T, i_obs_var] .== 0 )
112
+ @constraint (
113
+ Upper (bilevel_model),
114
+ up_var_map[obs_var] - Y[1 : T, i_obs_var] .== 0
115
+ )
101
116
i_obs_var += 1
102
117
end
103
118
@@ -114,7 +129,10 @@ function solve_bilevel(
114
129
if has_params (layer)
115
130
# get size and parameters W and b
116
131
(layer_size_out, layer_size_in) = size (layer. weight)
117
- W = @variable (Upper (bilevel_model), [1 : layer_size_out, 1 : layer_size_in])
132
+ W = @variable (
133
+ Upper (bilevel_model),
134
+ [1 : layer_size_out, 1 : layer_size_in]
135
+ )
118
136
if layer. bias == false
119
137
b = zeros (layer_size_out)
120
138
else
@@ -123,7 +141,8 @@ function solve_bilevel(
123
141
predictive_model_vars[i_layer] = Dict (:W => W, :b => b)
124
142
# build layer output as next layer input
125
143
for output_idx in values (model. forecast. input_output_map[1 ])
126
- layers_inpt[output_idx] = layer. σ (W * layers_inpt[output_idx]' .+ b)'
144
+ layers_inpt[output_idx] =
145
+ layer. σ (W * layers_inpt[output_idx]' .+ b)'
127
146
end
128
147
# if activation function layer, just apply
129
148
elseif supertype (typeof (layer)) == Function
@@ -144,7 +163,10 @@ function solve_bilevel(
144
163
ipred_var_count = 1
145
164
for pred_var in plan_forecast_vars (model)
146
165
low_pred_var = low_var_map[pred_var]
147
- @constraint (Lower (bilevel_model), low_pred_var .- y_hat[:, ipred_var_count] .== 0 )
166
+ @constraint (
167
+ Lower (bilevel_model),
168
+ low_pred_var .- y_hat[:, ipred_var_count] .== 0
169
+ )
148
170
ipred_var_count += 1
149
171
end
150
172
@@ -165,5 +187,8 @@ function solve_bilevel(
165
187
ilayer += 1
166
188
end
167
189
168
- return Solution (objective_value (bilevel_model), extract_params (model. forecast))
190
+ return Solution (
191
+ objective_value (bilevel_model),
192
+ extract_params (model. forecast),
193
+ )
169
194
end
0 commit comments