Commit 34a7eb4

add .JuliaFormatter.toml and apply custom setting format
1 parent 62bc90d commit 34a7eb4

15 files changed: +270, -171 lines

.JuliaFormatter.toml

Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
+# Configuration file for JuliaFormatter.jl
+# For more information, see: https://domluna.github.io/JuliaFormatter.jl/stable/config/
+
+indent = 4
+always_use_return = true
+margin = 80
+remove_extra_newlines = true
+short_to_long_function_def = true
+format_docstrings = true

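To reproduce the formatting applied in this commit, a minimal sketch (JuliaFormatter.jl picks up the .JuliaFormatter.toml at the project root, so this assumes the call is made from the repository root):

using JuliaFormatter

# Rewrite every Julia file in the repository using the settings above.
format(".")

# Check-only run (no files rewritten), e.g. as a CI formatting gate.
format("."; overwrite = false)
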
src/ApplicationDrivenLearning.jl

Lines changed: 20 additions & 11 deletions
@@ -1,13 +1,11 @@
 module ApplicationDrivenLearning
 
-
 using Flux
 using JuMP
 using DiffOpt
 import ParametricOptInterface as POI
 import Base.*, Base.+
 
-
 include("flux_utils.jl")
 include("predictive_model.jl")
 
@@ -34,14 +32,16 @@ struct Forecast{T}
     assess::T
 end
 
-+(p1::Forecast, p2::Forecast) = Forecast(p1.plan + p2.plan, p1.assess + p2.assess)
+function +(p1::Forecast, p2::Forecast)
+    return Forecast(p1.plan + p2.plan, p1.assess + p2.assess)
+end
 *(c::Number, p::Forecast) = Forecast(c * p.plan, c * p.assess)
 
 """
     Model <: JuMP.AbstractModel
 
-Create an empty ApplicationDrivenLearning.Model with empty plan and assess models,
-missing forecast model and default settings.
+Create an empty ApplicationDrivenLearning.Model with empty plan and assess
+models, missing forecast model and default settings.
 """
 mutable struct Model <: JuMP.AbstractModel
     plan::JuMP.Model
@@ -116,7 +116,7 @@ function set_forecast_model(
         forecast = PredictiveModel(network)
     end
     @assert forecast.output_size == size(model.forecast_vars, 1)
-    model.forecast = forecast
+    return model.forecast = forecast
 end
 
 """
@@ -151,7 +151,11 @@ end
 Creates new constraint to assess model that fixes policy variables.
 """
 function build_assess_model_policy_constraint(model::Model)
-    @constraint(model.assess, assess_policy_fix, assess_policy_vars(model) .== 0)
+    @constraint(
+        model.assess,
+        assess_policy_fix,
+        assess_policy_vars(model) .== 0
+    )
 end
 
 """
@@ -166,7 +170,7 @@ function build(model::Model)
 
     # build plan model
     build_plan_model_forecast_params(model)
-    build_assess_model_policy_constraint(model)
+    return build_assess_model_policy_constraint(model)
 end
 
 include("jump.jl")
@@ -184,7 +188,12 @@ include("optimizers/bilevel.jl")
 """
 Train model using given data and options.
 """
-function train!(model::Model, X::Matrix{<:Real}, y::Matrix{<:Real}, options::Options)
+function train!(
+    model::Model,
+    X::Matrix{<:Real},
+    y::Matrix{<:Real},
+    options::Options,
+)
     if options.mode == NelderMeadMode
         return train_with_nelder_mead!(model, X, y, options.params)
     elseif options.mode == GradientMode
@@ -194,8 +203,8 @@ function train!(model::Model, X::Matrix{<:Real}, y::Matrix{<:Real}, options::Opt
     elseif options.mode == GradientMPIMode
         return train_with_gradient_mpi!(model, X, y, options.params)
     elseif options.mode == BilevelMode
-        assert_msg = "BilevelMode not implemented for multiple forecasting models"
-        @assert length(model.forecast.networks) == 1 assert_msg
+        asr_msg = "BilevelMode not implemented for multiple forecasting models"
+        @assert length(model.forecast.networks) == 1 asr_msg
         return solve_bilevel(model, X, y, options.params)
     else
         # should never get here

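Most of the rewrites in the hunks above follow mechanically from the short_to_long_function_def, always_use_return and margin = 80 settings. A before/after sketch on a toy function (not part of the package) showing the first two:

# before: short-form definition with an implicit return value
double(x) = 2x

# after: the formatter expands it to a long-form definition with an explicit return
function double(x)
    return 2x
end
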
src/flux_utils.jl

Lines changed: 5 additions & 3 deletions
@@ -3,7 +3,8 @@ using Flux
 """
     extract_flux_params(model)
 
-Extract the parameters of a Flux model (Flux.Chain or Flux.Dense) into a single vector.
+Extract the parameters of a Flux model (Flux.Chain or Flux.Dense) into a single
+vector.
 """
 function extract_flux_params(model::Union{Flux.Chain,Flux.Dense})
     θ = Flux.params(model)
@@ -31,7 +32,8 @@ end
 """
     fix_flux_params_multi_model(models, θ)
 
-Return iterable of models after fixing the parameters from an adequate vector of parameters.
+Return iterable of models after fixing the parameters from an adequate vector
+of parameters.
 """
 function fix_flux_params_multi_model(models, θ::Vector{<:Real})
     i = 1
@@ -55,7 +57,7 @@ function has_params(layer)
         # Attempt to get parameters; if it works and isn't empty, return true
         return !isempty(Flux.params(layer))
     catch e
-        # If there is an error (e.g., method not matching), assume no parameters
+        # If there is an error (e.g. method not matching), assume no parameters
         return false
     end
 end

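For context, a small usage sketch of the two helpers documented above; the call pattern is inferred from the docstrings and signatures shown here, not from the full source:

using Flux

chain = Chain(Dense(3, 2, relu), Dense(2, 1))

# flatten all weights and biases of the model into a single vector
θ = extract_flux_params(chain)

# write a (possibly modified) vector of the same length back into the models
fix_flux_params_multi_model([chain], θ)
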
src/jump.jl

Lines changed: 14 additions & 6 deletions
@@ -47,7 +47,11 @@ function JuMP.build_variable(
     return ForecastInfo(info, info, kwargs)
 end
 
-function JuMP.add_variable(model::Model, forecast_info::ForecastInfo, name::String)
+function JuMP.add_variable(
+    model::Model,
+    forecast_info::ForecastInfo,
+    name::String,
+)
     forecast = Forecast(
         JuMP.add_variable(
             model.plan,
@@ -84,11 +88,14 @@ end
 
 # jump functions
 function JuMP.objective_sense(model::Model)
-    @assert JuMP.objective_sense(model.plan) == JuMP.objective_sense(model.assess)
+    @assert JuMP.objective_sense(model.plan) ==
+            JuMP.objective_sense(model.assess)
     return JuMP.objective_sense(model.plan)
 end
 
-JuMP.num_variables(m::Model) = JuMP.num_variables(m.plan) + JuMP.num_variables(m.assess)
+function JuMP.num_variables(m::Model)
+    return JuMP.num_variables(m.plan) + JuMP.num_variables(m.assess)
+end
 
 function JuMP.show_constraints_summary(io::IO, model::Model)
     println("Plan Model:")
@@ -113,16 +120,17 @@ function JuMP.set_optimizer(model::Model, builder, evaluate_duals::Bool = true)
     new_diff_optimizer = DiffOpt.diff_optimizer(builder)
     JuMP.set_optimizer(
         model.plan,
-        () -> POI.Optimizer(new_diff_optimizer; evaluate_duals = evaluate_duals),
+        () ->
+            POI.Optimizer(new_diff_optimizer; evaluate_duals = evaluate_duals),
     )
 
     # basic setting for assess model
-    JuMP.set_optimizer(model.assess, builder)
+    return JuMP.set_optimizer(model.assess, builder)
 end
 
 function JuMP.set_silent(model::Model)
     MOI.set(model.plan, MOI.Silent(), true)
-    MOI.set(model.assess, MOI.Silent(), true)
+    return MOI.set(model.assess, MOI.Silent(), true)
 end
 
 function JuMP.num_variables(model::Model)

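A hedged usage sketch of the optimizer wiring above; the HiGHS solver and the zero-argument Model constructor are illustrative assumptions:

using JuMP, HiGHS
import ApplicationDrivenLearning as ADL

model = ADL.Model()  # assumed empty-model constructor
# plan model gets a DiffOpt/POI-wrapped optimizer, assess model a plain one
JuMP.set_optimizer(model, HiGHS.Optimizer)
JuMP.set_silent(model)
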
src/optimizers/bilevel.jl

Lines changed: 40 additions & 15 deletions
@@ -2,7 +2,6 @@ using JuMP
 using Flux
 using BilevelJuMP
 
-
 function solve_bilevel(
     model::Model,
     X::Matrix{<:Real},
@@ -27,10 +26,12 @@ function solve_bilevel(
     T = size(Y, 1)
 
     # lower model variables
-    low_var_map = Dict{JuMP.VariableRef,Vector{BilevelJuMP.BilevelVariableRef}}()
+    low_var_map =
+        Dict{JuMP.VariableRef,Vector{BilevelJuMP.BilevelVariableRef}}()
     for pre_var in all_variables(model.plan)
         low_var_name = string(name(pre_var), "_low")
-        low_var_ref = @variable(Lower(bilevel_model), [1:T], base_name = low_var_name)
+        low_var_ref =
+            @variable(Lower(bilevel_model), [1:T], base_name = low_var_name)
         if has_lower_bound(pre_var)
             set_lower_bound.(low_var_ref, lower_bound(pre_var))
         end
@@ -45,7 +46,8 @@ function solve_bilevel(
     for post_var in all_variables(model.assess)
         if !(post_var in assess_policy_vars(model))
             up_var_name = string(name(post_var), "_up")
-            up_var_ref = @variable(Upper(bilevel_model), [1:T], base_name = up_var_name)
+            up_var_ref =
+                @variable(Upper(bilevel_model), [1:T], base_name = up_var_name)
             if has_lower_bound(post_var)
                 set_lower_bound.(up_var_ref, lower_bound(post_var))
             end
@@ -65,20 +67,30 @@ function solve_bilevel(
     end
 
     # lower model base constraints
-    for pre_con in
-        JuMP.all_constraints(model.plan, include_variable_in_set_constraints = false)
+    for pre_con in JuMP.all_constraints(
+        model.plan,
+        include_variable_in_set_constraints = false,
+    )
         pre_con_func = JuMP.constraint_object(pre_con).func
         lhs = [value(x -> low_var_map[x][t], pre_con_func) for t = 1:T]
-        @constraint(Lower(bilevel_model), lhs .∈ JuMP.constraint_object(pre_con).set)
+        @constraint(
+            Lower(bilevel_model),
+            lhs .∈ JuMP.constraint_object(pre_con).set
+        )
     end
 
     # upper model base constraints
-    for post_con in
-        JuMP.all_constraints(model.assess, include_variable_in_set_constraints = false)
+    for post_con in JuMP.all_constraints(
+        model.assess,
+        include_variable_in_set_constraints = false,
+    )
         if name(post_con) != "assess_policy_fix"
             post_con_func = JuMP.constraint_object(post_con).func
             lhs = [value(x -> up_var_map[x][t], post_con_func) for t = 1:T]
-            @constraint(Upper(bilevel_model), lhs .∈ JuMP.constraint_object(post_con).set)
+            @constraint(
+                Upper(bilevel_model),
+                lhs .∈ JuMP.constraint_object(post_con).set
+            )
         end
     end
 
@@ -97,7 +109,10 @@ function solve_bilevel(
     # fix upper model observations
     i_obs_var = 1
     for obs_var in assess_forecast_vars(model)
-        @constraint(Upper(bilevel_model), up_var_map[obs_var] - Y[1:T, i_obs_var] .== 0)
+        @constraint(
+            Upper(bilevel_model),
+            up_var_map[obs_var] - Y[1:T, i_obs_var] .== 0
+        )
         i_obs_var += 1
     end
 
@@ -114,7 +129,10 @@ function solve_bilevel(
         if has_params(layer)
             # get size and parameters W and b
            (layer_size_out, layer_size_in) = size(layer.weight)
-            W = @variable(Upper(bilevel_model), [1:layer_size_out, 1:layer_size_in])
+            W = @variable(
+                Upper(bilevel_model),
+                [1:layer_size_out, 1:layer_size_in]
+            )
             if layer.bias == false
                 b = zeros(layer_size_out)
             else
@@ -123,7 +141,8 @@ function solve_bilevel(
             predictive_model_vars[i_layer] = Dict(:W => W, :b => b)
             # build layer output as next layer input
            for output_idx in values(model.forecast.input_output_map[1])
-                layers_inpt[output_idx] = layer.σ(W * layers_inpt[output_idx]' .+ b)'
+                layers_inpt[output_idx] =
+                    layer.σ(W * layers_inpt[output_idx]' .+ b)'
            end
         # if activation function layer, just apply
         elseif supertype(typeof(layer)) == Function
@@ -144,7 +163,10 @@ function solve_bilevel(
     ipred_var_count = 1
     for pred_var in plan_forecast_vars(model)
         low_pred_var = low_var_map[pred_var]
-        @constraint(Lower(bilevel_model), low_pred_var .- y_hat[:, ipred_var_count] .== 0)
+        @constraint(
+            Lower(bilevel_model),
+            low_pred_var .- y_hat[:, ipred_var_count] .== 0
+        )
         ipred_var_count += 1
     end
 
@@ -165,5 +187,8 @@ function solve_bilevel(
         ilayer += 1
     end
 
-    return Solution(objective_value(bilevel_model), extract_params(model.forecast))
+    return Solution(
+        objective_value(bilevel_model),
+        extract_params(model.forecast),
+    )
 end

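The multi-line @variable calls above use BilevelJuMP's Upper/Lower pattern. A standalone sketch of that pattern, with an illustrative solver and reformulation mode rather than anything this package prescribes:

using JuMP, BilevelJuMP, HiGHS

bm = BilevelModel(HiGHS.Optimizer, mode = BilevelJuMP.SOS1Mode())
T = 3
low_var_ref =
    @variable(Lower(bm), [1:T], base_name = "plan_var_low")
up_var_ref =
    @variable(Upper(bm), [1:T], base_name = "assess_var_up")
set_lower_bound.(low_var_ref, 0.0)
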
src/optimizers/gradient.jl

Lines changed: 11 additions & 4 deletions
@@ -1,7 +1,8 @@
 using Flux
 
 """
-Compute assess cost and cost gradient (with respect to predicted values) based on incomplete batch of examples.
+Compute assess cost and cost gradient (with respect to predicted values) based
+on incomplete batch of examples.
 """
 function stochastic_compute(model, X, Y, batch, compute_full_cost::Bool)
     C, dC = compute_cost(model, X[batch, :], Y[batch, :], true)
@@ -12,7 +13,8 @@ function stochastic_compute(model, X, Y, batch, compute_full_cost::Bool)
 end
 
 """
-Compute assess cost and cost gradient (with respect to predicted values) based on complete batch of examples.
+Compute assess cost and cost gradient (with respect to predicted values) based
+on complete batch of examples.
 """
 function deterministic_compute(model, X, Y)
     C, dC = compute_cost(model, X, Y, true)
@@ -53,7 +55,13 @@ function train_with_gradient!(
 
         if stochastic
             epochx = X[batches[epoch, :], :]
-            C, dC = stochastic_compute(model, X, Y, batches[epoch, :], compute_full_cost)
+            C, dC = stochastic_compute(
+                model,
+                X,
+                Y,
+                batches[epoch, :],
+                compute_full_cost,
+            )
         else
            epochx = X
            C, dC = deterministic_compute(model, X, Y)
@@ -80,7 +88,6 @@ function train_with_gradient!(
 
         # take gradient step
         apply_gradient!(model.forecast, dC, epochx, rule)
-
     end
 
     # fix best model
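
The stochastic branch above indexes X with batches[epoch, :]. A hypothetical sketch of how such an epoch-by-batch index matrix could be built (make_batches is not a function from this package):

using Random

function make_batches(n_samples::Int, n_epochs::Int, batch_size::Int)
    batches = Matrix{Int}(undef, n_epochs, batch_size)
    for epoch in 1:n_epochs
        # sample batch_size distinct row indices for this epoch
        batches[epoch, :] = randperm(n_samples)[1:batch_size]
    end
    return batches
end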

src/optimizers/gradient_mpi.jl

Lines changed: 6 additions & 5 deletions
@@ -3,7 +3,6 @@ import ParametricOptInterface as POI
 using MPI
 import JobQueueMPI as JQM
 
-
 function train_with_gradient_mpi!(
     model::Model,
     X::Matrix{<:Real},
@@ -52,7 +51,6 @@ function train_with_gradient_mpi!(
        end
 
        return step_cost, step_grad
-
    end
 
    # call optim as the controller
@@ -75,18 +73,21 @@ function train_with_gradient_mpi!(
                (v) -> compute_cost_and_gradients(v[1], v[2], true),
                [[curr_θ, i] for i in batches[epoch, :]],
            )
-            dCdy = sum([r[2] for r in pmap_result_with_gradients]) ./ batch_size
+            dCdy =
+                sum([r[2] for r in pmap_result_with_gradients]) ./ batch_size
 
            if compute_full_cost
                # broadcast `is_done = false` again
                MPI.bcast(is_done, MPI.COMM_WORLD)
 
                # compute full cost
                pmap_result_without_gradients = JQM.pmap(
-                    (v) -> compute_cost_and_gradients(v[1], v[2], false),
+                    (v) ->
+                        compute_cost_and_gradients(v[1], v[2], false),
                    [[curr_θ, i] for i = 1:T],
                )
-                curr_C = sum([r[1] for r in pmap_result_without_gradients]) ./ T
+                curr_C =
+                    sum([r[1] for r in pmap_result_without_gradients]) ./ T
            end
 
        else
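
For intuition, a small hedged sketch of the controller-side reduction in this hunk: each JQM.pmap result is assumed to be a (cost, gradient) pair that the controller averages (gradients over the batch, costs over however many samples were evaluated); the values are placeholders:

results = [(1.0, [0.1, 0.2]), (3.0, [0.3, 0.4])]  # stand-in for pmap output
n = length(results)

# same reductions as above: average the gradients and, optionally, the costs
dCdy = sum([r[2] for r in results]) ./ n
curr_C = sum([r[1] for r in results]) ./ n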
