Skip to content

Commit 0a9d56e

Browse files
committed
apply JuliaFormatter
1 parent dce8a17 commit 0a9d56e

File tree

8 files changed

+121
-69
lines changed

8 files changed

+121
-69
lines changed

src/ApplicationDrivenLearning.jl

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -160,11 +160,7 @@ function set_forecast_model(
160160
if forecast.input_output_map == nothing
161161
forecast = PredictiveModel(
162162
deepcopy(forecast.networks),
163-
[
164-
Dict(
165-
collect(1:forecast.input_size) => model.forecast_vars
166-
)
167-
],
163+
[Dict(collect(1:forecast.input_size) => model.forecast_vars)],
168164
model.forecast_vars,
169165
forecast.input_size,
170166
forecast.output_size,
@@ -185,9 +181,8 @@ function forecast(model::Model, X::AbstractMatrix)
185181
# check if input output map is set
186182
if model.forecast.input_output_map === nothing
187183
# set input output map using forecast variables
188-
model.forecast.input_output_map = Dict(
189-
collect(1:model.forecast.input_size) => model.forecast_vars
190-
)
184+
model.forecast.input_output_map =
185+
Dict(collect(1:model.forecast.input_size) => model.forecast_vars)
191186
end
192187

193188
return model.forecast(X)
@@ -251,7 +246,7 @@ Train model using given data and options.
251246
function train!(
252247
model::Model,
253248
X::Matrix{<:Real},
254-
y::Dict{<:Forecast, <:Vector},
249+
y::Dict{<:Forecast,<:Vector},
255250
options::Options,
256251
)
257252
if options.mode == NelderMeadMode

src/optimizers/bilevel.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using BilevelJuMP
55
function solve_bilevel(
66
model::Model,
77
X::Matrix{<:Real},
8-
Y::Dict{<:Forecast, <:Vector},
8+
Y::Dict{<:Forecast,<:Vector},
99
params::Dict{Symbol,Any},
1010
)
1111

@@ -164,10 +164,7 @@ function solve_bilevel(
164164
# and apply prediction on lower model as constraint
165165
for pred_var in model.forecast_vars
166166
low_pred_var = low_var_map[pred_var.plan]
167-
@constraint(
168-
Lower(bilevel_model),
169-
low_pred_var .- y_hat[pred_var] .== 0
170-
)
167+
@constraint(Lower(bilevel_model), low_pred_var .- y_hat[pred_var] .== 0)
171168
end
172169

173170
# solve model

src/optimizers/gradient.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ end
2525
function train_with_gradient!(
2626
model::Model,
2727
X::Matrix{<:Real},
28-
Y::Dict{<:Forecast, <:Vector},
28+
Y::Dict{<:Forecast,<:Vector},
2929
params::Dict{Symbol,Any},
3030
)
3131
# extract params

src/optimizers/nelder_mead.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ using Optim
33
function train_with_nelder_mead!(
44
model::Model,
55
X::Matrix{<:Real},
6-
Y::Dict{<:Forecast, <:Vector},
6+
Y::Dict{<:Forecast,<:Vector},
77
params::Dict{Symbol,Any},
88
)
99

src/predictive_model.jl

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,12 @@ include("variable_indexed_structs.jl")
1111
1212
Get the ordered output variables from the input-output map.
1313
"""
14-
function get_ordered_output_variables(input_output_map::Vector{<:Dict{Vector{Int},<:Vector{<:Forecast}}})
14+
function get_ordered_output_variables(
15+
input_output_map::Vector{<:Dict{Vector{Int},<:Vector{<:Forecast}}},
16+
)
1517
return reduce(
16-
vcat,
17-
[
18-
reduce(vcat, values(iomap))
19-
for iomap in input_output_map
20-
]
18+
vcat,
19+
[reduce(vcat, values(iomap)) for iomap in input_output_map],
2120
)
2221
end
2322

@@ -26,15 +25,11 @@ end
2625
2726
Get the input indices from the input-output map.
2827
"""
29-
function get_input_indices(input_output_map::Vector{<:Dict{Vector{Int},<:Vector{<:Forecast}}})
28+
function get_input_indices(
29+
input_output_map::Vector{<:Dict{Vector{Int},<:Vector{<:Forecast}}},
30+
)
3031
return unique(
31-
reduce(
32-
vcat,
33-
[
34-
reduce(vcat, keys(iomap))
35-
for iomap in input_output_map
36-
]
37-
)
32+
reduce(vcat, [reduce(vcat, keys(iomap)) for iomap in input_output_map]),
3833
)
3934
end
4035

@@ -43,7 +38,9 @@ end
4338
4439
Get the maximum input index from the input-output maps.
4540
"""
46-
function get_max_input_index(input_output_map::Vector{<:Dict{Vector{Int},<:Vector{<:Forecast}}})
41+
function get_max_input_index(
42+
input_output_map::Vector{<:Dict{Vector{Int},<:Vector{<:Forecast}}},
43+
)
4744
return maximum(get_input_indices(input_output_map))
4845
end
4946

@@ -82,14 +79,20 @@ julia> pred_model = PredictiveModel(
8279
"""
8380
struct PredictiveModel
8481
networks::Union{Vector{<:Flux.Chain},Vector{<:Flux.Dense}}
85-
input_output_map::Union{Vector{<:Dict{Vector{Int},<:Vector{<:Forecast}}},Nothing}
82+
input_output_map::Union{
83+
Vector{<:Dict{Vector{Int},<:Vector{<:Forecast}}},
84+
Nothing,
85+
}
8686
output_variables::Union{Vector{<:Forecast},Nothing}
8787
input_size::Int
8888
output_size::Int
8989

9090
function PredictiveModel(
9191
networks::Union{Vector{<:Flux.Chain},Vector{<:Flux.Dense}},
92-
input_output_map::Union{Vector{<:Dict{Vector{Int},<:Vector{<:Forecast}}},Nothing},
92+
input_output_map::Union{
93+
Vector{<:Dict{Vector{Int},<:Vector{<:Forecast}}},
94+
Nothing,
95+
},
9396
output_variables::Union{Vector{<:Forecast},Nothing},
9497
input_size::Int,
9598
output_size::Int,
@@ -99,7 +102,7 @@ struct PredictiveModel
99102
input_output_map,
100103
output_variables,
101104
input_size,
102-
output_size
105+
output_size,
103106
)
104107
end
105108
end
@@ -112,7 +115,10 @@ from Flux models and input/output map.
112115
"""
113116
function PredictiveModel(
114117
networks::Union{Vector{<:Flux.Chain},Vector{<:Flux.Dense}},
115-
input_output_map::Union{Vector{<:Dict{Vector{Int},<:Vector{<:Forecast}}},Nothing},
118+
input_output_map::Union{
119+
Vector{<:Dict{Vector{Int},<:Vector{<:Forecast}}},
120+
Nothing,
121+
},
116122
)
117123
output_variables = get_ordered_output_variables(input_output_map)
118124
input_size = get_max_input_index(input_output_map)
@@ -122,7 +128,7 @@ function PredictiveModel(
122128
input_output_map,
123129
output_variables,
124130
input_size,
125-
output_size
131+
output_size,
126132
)
127133
end
128134

@@ -231,7 +237,7 @@ specifying that only the networks field is trainable.
231237
Flux.trainable(model::PredictiveModel) = (networks = model.networks,)
232238

233239
# Tells Flux to only look at the 'network' field when setting up or traversing
234-
@Functors.functor PredictiveModel (networks,)
240+
Functors.@functor PredictiveModel (networks,)
235241

236242
"""
237243
(model::PredictiveModel)(X::AbstractMatrix, ignore_index::Bool = false)
@@ -280,7 +286,7 @@ end
280286
"""
281287
(model::PredictiveModel)(x::AbstractVector, ignore_index::Bool = false)
282288
283-
Predict the output of the model for a given input vector.
289+
Predict the output of the model for a given input vector.
284290
If the model has no input-output map, the network is applied directly to the input.
285291
If ignore_index is true, the output variables are not returned.
286292
"""
@@ -375,6 +381,10 @@ function apply_gradient!(
375381
X::Matrix{<:Real},
376382
opt_state,
377383
)
378-
return apply_gradient!(model, dCdy[model.output_variables].data, X, opt_state)
384+
return apply_gradient!(
385+
model,
386+
dCdy[model.output_variables].data,
387+
X,
388+
opt_state,
389+
)
379390
end
380-

src/simulation.jl

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,12 @@ function compute_single_step_cost(
44
yhat::VariableIndexedVector,
55
)
66
# set forecast params as prediction output
7-
MOI.set.(model.plan, POI.ParameterValue(), model.plan_forecast_params, yhat[model.forecast_vars].data)
7+
MOI.set.(
8+
model.plan,
9+
POI.ParameterValue(),
10+
model.plan_forecast_params,
11+
yhat[model.forecast_vars].data,
12+
)
813
# optimize plan model
914
optimize!(model.plan)
1015
# check for solution and fix assess policy vars
@@ -81,7 +86,7 @@ Compute the cost function (C) based on the model predictions and the true values
8186
function compute_cost(
8287
model::Model,
8388
X::Matrix{<:Real},
84-
Y::Dict{<:Forecast, <:Vector},
89+
Y::Dict{<:Forecast,<:Vector},
8590
with_gradients::Bool = false,
8691
aggregate::Bool = true,
8792
)
@@ -100,7 +105,7 @@ function compute_cost(
100105
dCdy = VariableIndexedVector{Float32}(undef, model.forecast_vars)
101106
dC = VariableIndexedMatrix{Float32}(undef, model.forecast_vars, T)
102107

103-
function _get_index_y(Y::Dict{<:Forecast, <:Vector}, idx::Int)
108+
function _get_index_y(Y::Dict{<:Forecast,<:Vector}, idx::Int)
104109
var_index = Vector{Forecast}(undef, model.forecast.output_size)
105110
y_values = Vector{Real}(undef, model.forecast.output_size)
106111
for (i, (fvar, vals)) in enumerate(Y)
@@ -110,7 +115,10 @@ function compute_cost(
110115
return VariableIndexedVector(y_values, var_index)
111116
end
112117

113-
function _compute_step(y::VariableIndexedVector, yhat::VariableIndexedVector)
118+
function _compute_step(
119+
y::VariableIndexedVector,
120+
yhat::VariableIndexedVector,
121+
)
114122
c = compute_single_step_cost(model, y, yhat)
115123
if with_gradients
116124
dc = compute_single_step_gradient(model, dCdz, dCdy)

0 commit comments

Comments
 (0)