1
1
module ApplicationDrivenLearning
2
2
3
-
4
3
using Flux
5
4
using JuMP
6
5
using DiffOpt
7
6
import ParametricOptInterface as POI
8
7
import Base.* , Base.+
9
8
10
-
11
9
include (" flux_utils.jl" )
12
10
include (" predictive_model.jl" )
13
11
@@ -22,7 +20,7 @@ struct Policy{T}
22
20
end
23
21
24
22
+ (p1:: Policy , p2:: Policy ) = Policy (p1. plan + p2. plan, p1. assess + p2. assess)
25
- * (c:: Number , p:: Policy ) = Policy (c* p. plan, c* p. assess)
23
+ * (c:: Number , p:: Policy ) = Policy (c * p. plan, c * p. assess)
26
24
27
25
"""
28
26
Forecast{T}
@@ -34,19 +32,21 @@ struct Forecast{T}
34
32
assess:: T
35
33
end
36
34
37
- + (p1:: Forecast , p2:: Forecast ) = Forecast (p1. plan + p2. plan, p1. assess + p2. assess)
38
- * (c:: Number , p:: Forecast ) = Forecast (c* p. plan, c* p. assess)
35
+ function + (p1:: Forecast , p2:: Forecast )
36
+ return Forecast (p1. plan + p2. plan, p1. assess + p2. assess)
37
+ end
38
+ * (c:: Number , p:: Forecast ) = Forecast (c * p. plan, c * p. assess)
39
39
40
40
"""
41
41
Model <: JuMP.AbstractModel
42
42
43
- Create an empty ApplicationDrivenLearning.Model with empty plan and assess models,
44
- missing forecast model and default settings.
43
+ Create an empty ApplicationDrivenLearning.Model with empty plan and assess
44
+ models, missing forecast model and default settings.
45
45
"""
46
46
mutable struct Model <: JuMP.AbstractModel
47
47
plan:: JuMP.Model
48
48
assess:: JuMP.Model
49
- forecast:: Union{PredictiveModel, Nothing}
49
+ forecast:: Union{PredictiveModel,Nothing}
50
50
51
51
# variable arrays
52
52
policy_vars:: Vector{Policy}
@@ -62,9 +62,14 @@ mutable struct Model <: JuMP.AbstractModel
62
62
assess = JuMP. Model ()
63
63
64
64
return new (
65
- plan, assess, nothing ,
66
- Vector {Policy} (), Vector {Forecast} (), Vector {JuMP.VariableRef} (),
67
- Dict {Symbol,Any} (), false
65
+ plan,
66
+ assess,
67
+ nothing ,
68
+ Vector {Policy} (),
69
+ Vector {Forecast} (),
70
+ Vector {JuMP.VariableRef} (),
71
+ Dict {Symbol,Any} (),
72
+ false ,
68
73
)
69
74
end
70
75
end
@@ -102,16 +107,16 @@ Sets Chain, Dense or custom PredictiveModel object as
102
107
forecast model.
103
108
"""
104
109
function set_forecast_model (
105
- model:: Model ,
106
- network:: Union{PredictiveModel, Flux.Chain, Flux.Dense}
110
+ model:: Model ,
111
+ network:: Union{PredictiveModel,Flux.Chain,Flux.Dense} ,
107
112
)
108
113
if typeof (network) == PredictiveModel
109
114
forecast = network
110
115
else
111
116
forecast = PredictiveModel (network)
112
117
end
113
118
@assert forecast. output_size == size (model. forecast_vars, 1 )
114
- model. forecast = forecast
119
+ return model. forecast = forecast
115
120
end
116
121
117
122
"""
@@ -131,13 +136,13 @@ function build_plan_model_forecast_params(model::Model)
131
136
# adds parametrized forecast variables using MOI.Parameter
132
137
forecast_size = size (model. forecast_vars)[1 ]
133
138
model. plan_forecast_params = @variable (
134
- model. plan,
139
+ model. plan,
135
140
_forecast[1 : forecast_size] in MOI. Parameter .(zeros (forecast_size))
136
141
)
137
142
# fixes old and new prediction variables together
138
143
@constraint (
139
- model. plan,
140
- plan_forecast_fix,
144
+ model. plan,
145
+ plan_forecast_fix,
141
146
model. plan_forecast_params .== plan_forecast_vars (model)
142
147
)
143
148
end
@@ -147,8 +152,8 @@ Creates new constraint to assess model that fixes policy variables.
147
152
"""
148
153
function build_assess_model_policy_constraint (model:: Model )
149
154
@constraint (
150
- model. assess,
151
- assess_policy_fix,
155
+ model. assess,
156
+ assess_policy_fix,
152
157
assess_policy_vars (model) .== 0
153
158
)
154
159
end
@@ -165,7 +170,7 @@ function build(model::Model)
165
170
166
171
# build plan model
167
172
build_plan_model_forecast_params (model)
168
- build_assess_model_policy_constraint (model)
173
+ return build_assess_model_policy_constraint (model)
169
174
end
170
175
171
176
include (" jump.jl" )
@@ -184,9 +189,9 @@ include("optimizers/bilevel.jl")
184
189
Train model using given data and options.
185
190
"""
186
191
function train! (
187
- model:: Model ,
188
- X:: Matrix{<:Real} ,
189
- y:: Matrix{<:Real} ,
192
+ model:: Model ,
193
+ X:: Matrix{<:Real} ,
194
+ y:: Matrix{<:Real} ,
190
195
options:: Options ,
191
196
)
192
197
if options. mode == NelderMeadMode
@@ -198,17 +203,23 @@ function train!(
198
203
elseif options. mode == GradientMPIMode
199
204
return train_with_gradient_mpi! (model, X, y, options. params)
200
205
elseif options. mode == BilevelMode
201
- assert_msg = " BilevelMode not implemented for multiple forecasting models"
202
- @assert length (model. forecast. networks) == 1 assert_msg
206
+ asr_msg = " BilevelMode not implemented for multiple forecasting models"
207
+ @assert length (model. forecast. networks) == 1 asr_msg
203
208
return solve_bilevel (model, X, y, options. params)
204
209
else
205
210
# should never get here
206
211
throw (ArgumentError (" Invalid optimization method" ))
207
212
end
208
213
end
209
214
210
- export Model, PredictiveModel, Plan, Assess,
211
- Policy, Forecast,
212
- set_forecast_model, forecast,
213
- compute_cost, train!
215
+ export Model,
216
+ PredictiveModel,
217
+ Plan,
218
+ Assess,
219
+ Policy,
220
+ Forecast,
221
+ set_forecast_model,
222
+ forecast,
223
+ compute_cost,
224
+ train!
214
225
end
0 commit comments