Skip to content

Commit e76601d

Browse files
authored
Merge pull request #8 from LAMPSPUC/fix/formatter
Apply JuliaFormatter
2 parents e9bf122 + 34a7eb4 commit e76601d

16 files changed

+397
-342
lines changed

.JuliaFormatter.toml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# Configuration file for JuliaFormatter.jl
2+
# For more information, see: https://domluna.github.io/JuliaFormatter.jl/stable/config/
3+
4+
indent = 4
5+
always_use_return = true
6+
margin = 80
7+
remove_extra_newlines = true
8+
short_to_long_function_def = true
9+
format_docstrings = true

docs/make.jl

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,10 @@ push!(LOAD_PATH, "../src")
44
using ApplicationDrivenLearning
55

66
makedocs(;
7-
modules=[ApplicationDrivenLearning],
8-
doctest=false,
9-
clean=true,
10-
sitename="ApplicationDrivenLearning.jl",
11-
authors="Giovanni Amorim, Joaquim Garcia",
12-
pages=[
13-
"Home" => "index.md",
14-
"API Reference" => "reference.md"
15-
]
16-
)
7+
modules = [ApplicationDrivenLearning],
8+
doctest = false,
9+
clean = true,
10+
sitename = "ApplicationDrivenLearning.jl",
11+
authors = "Giovanni Amorim, Joaquim Garcia",
12+
pages = ["Home" => "index.md", "API Reference" => "reference.md"],
13+
)

src/ApplicationDrivenLearning.jl

Lines changed: 40 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
module ApplicationDrivenLearning
22

3-
43
using Flux
54
using JuMP
65
using DiffOpt
76
import ParametricOptInterface as POI
87
import Base.*, Base.+
98

10-
119
include("flux_utils.jl")
1210
include("predictive_model.jl")
1311

@@ -22,7 +20,7 @@ struct Policy{T}
2220
end
2321

2422
+(p1::Policy, p2::Policy) = Policy(p1.plan + p2.plan, p1.assess + p2.assess)
25-
*(c::Number, p::Policy) = Policy(c*p.plan, c*p.assess)
23+
*(c::Number, p::Policy) = Policy(c * p.plan, c * p.assess)
2624

2725
"""
2826
Forecast{T}
@@ -34,19 +32,21 @@ struct Forecast{T}
3432
assess::T
3533
end
3634

37-
+(p1::Forecast, p2::Forecast) = Forecast(p1.plan + p2.plan, p1.assess + p2.assess)
38-
*(c::Number, p::Forecast) = Forecast(c*p.plan, c*p.assess)
35+
function +(p1::Forecast, p2::Forecast)
36+
return Forecast(p1.plan + p2.plan, p1.assess + p2.assess)
37+
end
38+
*(c::Number, p::Forecast) = Forecast(c * p.plan, c * p.assess)
3939

4040
"""
4141
Model <: JuMP.AbstractModel
4242
43-
Create an empty ApplicationDrivenLearning.Model with empty plan and assess models,
44-
missing forecast model and default settings.
43+
Create an empty ApplicationDrivenLearning.Model with empty plan and assess
44+
models, missing forecast model and default settings.
4545
"""
4646
mutable struct Model <: JuMP.AbstractModel
4747
plan::JuMP.Model
4848
assess::JuMP.Model
49-
forecast::Union{PredictiveModel, Nothing}
49+
forecast::Union{PredictiveModel,Nothing}
5050

5151
# variable arrays
5252
policy_vars::Vector{Policy}
@@ -62,9 +62,14 @@ mutable struct Model <: JuMP.AbstractModel
6262
assess = JuMP.Model()
6363

6464
return new(
65-
plan, assess, nothing,
66-
Vector{Policy}(), Vector{Forecast}(), Vector{JuMP.VariableRef}(),
67-
Dict{Symbol,Any}(), false
65+
plan,
66+
assess,
67+
nothing,
68+
Vector{Policy}(),
69+
Vector{Forecast}(),
70+
Vector{JuMP.VariableRef}(),
71+
Dict{Symbol,Any}(),
72+
false,
6873
)
6974
end
7075
end
@@ -102,16 +107,16 @@ Sets Chain, Dense or custom PredictiveModel object as
102107
forecast model.
103108
"""
104109
function set_forecast_model(
105-
model::Model,
106-
network::Union{PredictiveModel, Flux.Chain, Flux.Dense}
110+
model::Model,
111+
network::Union{PredictiveModel,Flux.Chain,Flux.Dense},
107112
)
108113
if typeof(network) == PredictiveModel
109114
forecast = network
110115
else
111116
forecast = PredictiveModel(network)
112117
end
113118
@assert forecast.output_size == size(model.forecast_vars, 1)
114-
model.forecast = forecast
119+
return model.forecast = forecast
115120
end
116121

117122
"""
@@ -131,13 +136,13 @@ function build_plan_model_forecast_params(model::Model)
131136
# adds parametrized forecast variables using MOI.Parameter
132137
forecast_size = size(model.forecast_vars)[1]
133138
model.plan_forecast_params = @variable(
134-
model.plan,
139+
model.plan,
135140
_forecast[1:forecast_size] in MOI.Parameter.(zeros(forecast_size))
136141
)
137142
# fixes old and new prediction variables together
138143
@constraint(
139-
model.plan,
140-
plan_forecast_fix,
144+
model.plan,
145+
plan_forecast_fix,
141146
model.plan_forecast_params .== plan_forecast_vars(model)
142147
)
143148
end
@@ -147,8 +152,8 @@ Creates new constraint to assess model that fixes policy variables.
147152
"""
148153
function build_assess_model_policy_constraint(model::Model)
149154
@constraint(
150-
model.assess,
151-
assess_policy_fix,
155+
model.assess,
156+
assess_policy_fix,
152157
assess_policy_vars(model) .== 0
153158
)
154159
end
@@ -165,7 +170,7 @@ function build(model::Model)
165170

166171
# build plan model
167172
build_plan_model_forecast_params(model)
168-
build_assess_model_policy_constraint(model)
173+
return build_assess_model_policy_constraint(model)
169174
end
170175

171176
include("jump.jl")
@@ -184,9 +189,9 @@ include("optimizers/bilevel.jl")
184189
Train model using given data and options.
185190
"""
186191
function train!(
187-
model::Model,
188-
X::Matrix{<:Real},
189-
y::Matrix{<:Real},
192+
model::Model,
193+
X::Matrix{<:Real},
194+
y::Matrix{<:Real},
190195
options::Options,
191196
)
192197
if options.mode == NelderMeadMode
@@ -198,17 +203,23 @@ function train!(
198203
elseif options.mode == GradientMPIMode
199204
return train_with_gradient_mpi!(model, X, y, options.params)
200205
elseif options.mode == BilevelMode
201-
assert_msg = "BilevelMode not implemented for multiple forecasting models"
202-
@assert length(model.forecast.networks) == 1 assert_msg
206+
asr_msg = "BilevelMode not implemented for multiple forecasting models"
207+
@assert length(model.forecast.networks) == 1 asr_msg
203208
return solve_bilevel(model, X, y, options.params)
204209
else
205210
# should never get here
206211
throw(ArgumentError("Invalid optimization method"))
207212
end
208213
end
209214

210-
export Model, PredictiveModel, Plan, Assess,
211-
Policy, Forecast,
212-
set_forecast_model, forecast,
213-
compute_cost, train!
215+
export Model,
216+
PredictiveModel,
217+
Plan,
218+
Assess,
219+
Policy,
220+
Forecast,
221+
set_forecast_model,
222+
forecast,
223+
compute_cost,
224+
train!
214225
end

src/flux_utils.jl

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@ using Flux
33
"""
44
extract_flux_params(model)
55
6-
Extract the parameters of a Flux model (Flux.Chain or Flux.Dense) into a single vector.
6+
Extract the parameters of a Flux model (Flux.Chain or Flux.Dense) into a single
7+
vector.
78
"""
8-
function extract_flux_params(model::Union{Flux.Chain, Flux.Dense})
9+
function extract_flux_params(model::Union{Flux.Chain,Flux.Dense})
910
θ = Flux.params(model)
1011
return reduce(vcat, [vec(p) for p in θ])
1112
end
@@ -15,7 +16,10 @@ end
1516
1617
Return model after fixing the parameters from an adequate vector of parameters.
1718
"""
18-
function fix_flux_params_single_model(model::Union{Flux.Chain, Flux.Dense}, θ::Vector{<:Real})
19+
function fix_flux_params_single_model(
20+
model::Union{Flux.Chain,Flux.Dense},
21+
θ::Vector{<:Real},
22+
)
1923
i = 1
2024
for p in Flux.params(model)
2125
psize = prod(size(p))
@@ -28,12 +32,10 @@ end
2832
"""
2933
fix_flux_params_multi_model(models, θ)
3034
31-
Return iterable of models after fixing the parameters from an adequate vector of parameters.
35+
Return iterable of models after fixing the parameters from an adequate vector
36+
of parameters.
3237
"""
33-
function fix_flux_params_multi_model(
34-
models,
35-
θ::Vector{<:Real}
36-
)
38+
function fix_flux_params_multi_model(models, θ::Vector{<:Real})
3739
i = 1
3840
for model in models
3941
for p in Flux.params(model)
@@ -55,7 +57,7 @@ function has_params(layer)
5557
# Attempt to get parameters; if it works and isn't empty, return true
5658
return !isempty(Flux.params(layer))
5759
catch e
58-
# If there is an error (e.g., method not matching), assume no parameters
60+
# If there is an error (e.g. method not matching), assume no parameters
5961
return false
6062
end
6163
end

src/jump.jl

Lines changed: 31 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -11,29 +11,21 @@ function JuMP.build_variable(
1111
::Type{Policy};
1212
kwargs...,
1313
)
14-
return PolicyInfo(
15-
info,
16-
info,
17-
kwargs
18-
)
14+
return PolicyInfo(info, info, kwargs)
1915
end
2016

21-
function JuMP.add_variable(
22-
model::Model,
23-
policy_info::PolicyInfo,
24-
name::String
25-
)
17+
function JuMP.add_variable(model::Model, policy_info::PolicyInfo, name::String)
2618
policy = Policy(
2719
JuMP.add_variable(
28-
model.plan,
29-
JuMP.ScalarVariable(policy_info.plan),
30-
name * "_plan"
20+
model.plan,
21+
JuMP.ScalarVariable(policy_info.plan),
22+
name * "_plan",
3123
),
3224
JuMP.add_variable(
33-
model.assess,
34-
JuMP.ScalarVariable(policy_info.assess),
35-
name * "_assess"
36-
)
25+
model.assess,
26+
JuMP.ScalarVariable(policy_info.assess),
27+
name * "_assess",
28+
),
3729
)
3830
push!(model.policy_vars, policy)
3931
return policy
@@ -52,29 +44,25 @@ function JuMP.build_variable(
5244
::Type{Forecast};
5345
kwargs...,
5446
)
55-
return ForecastInfo(
56-
info,
57-
info,
58-
kwargs
59-
)
47+
return ForecastInfo(info, info, kwargs)
6048
end
6149

6250
function JuMP.add_variable(
63-
model::Model,
64-
forecast_info::ForecastInfo,
65-
name::String
51+
model::Model,
52+
forecast_info::ForecastInfo,
53+
name::String,
6654
)
6755
forecast = Forecast(
6856
JuMP.add_variable(
69-
model.plan,
70-
JuMP.ScalarVariable(forecast_info.plan),
71-
name * "_plan"
57+
model.plan,
58+
JuMP.ScalarVariable(forecast_info.plan),
59+
name * "_plan",
7260
),
7361
JuMP.add_variable(
74-
model.assess,
75-
JuMP.ScalarVariable(forecast_info.assess),
76-
name * "_assess"
77-
)
62+
model.assess,
63+
JuMP.ScalarVariable(forecast_info.assess),
64+
name * "_assess",
65+
),
7866
)
7967
push!(model.forecast_vars, forecast)
8068
return forecast
@@ -100,11 +88,14 @@ end
10088

10189
# jump functions
10290
function JuMP.objective_sense(model::Model)
103-
@assert JuMP.objective_sense(model.plan) == JuMP.objective_sense(model.assess)
91+
@assert JuMP.objective_sense(model.plan) ==
92+
JuMP.objective_sense(model.assess)
10493
return JuMP.objective_sense(model.plan)
10594
end
10695

107-
JuMP.num_variables(m::Model) = JuMP.num_variables(m.plan) + JuMP.num_variables(m.assess)
96+
function JuMP.num_variables(m::Model)
97+
return JuMP.num_variables(m.plan) + JuMP.num_variables(m.assess)
98+
end
10899

109100
function JuMP.show_constraints_summary(io::IO, model::Model)
110101
println("Plan Model:")
@@ -124,21 +115,22 @@ end
124115

125116
JuMP.object_dictionary(model::Model) = model.obj_dict
126117

127-
function JuMP.set_optimizer(model::Model, builder, evaluate_duals::Bool=true)
118+
function JuMP.set_optimizer(model::Model, builder, evaluate_duals::Bool = true)
128119
# set diffopt optimizer for plan model
129120
new_diff_optimizer = DiffOpt.diff_optimizer(builder)
130121
JuMP.set_optimizer(
131122
model.plan,
132-
() -> POI.Optimizer(new_diff_optimizer; evaluate_duals=evaluate_duals)
123+
() ->
124+
POI.Optimizer(new_diff_optimizer; evaluate_duals = evaluate_duals),
133125
)
134126

135127
# basic setting for assess model
136-
JuMP.set_optimizer(model.assess, builder)
128+
return JuMP.set_optimizer(model.assess, builder)
137129
end
138130

139131
function JuMP.set_silent(model::Model)
140132
MOI.set(model.plan, MOI.Silent(), true)
141-
MOI.set(model.assess, MOI.Silent(), true)
133+
return MOI.set(model.assess, MOI.Silent(), true)
142134
end
143135

144136
function JuMP.num_variables(model::Model)
@@ -160,4 +152,4 @@ function Base.print(io::IO, model::Model)
160152
else
161153
println(model.forecast.network)
162154
end
163-
end
155+
end

0 commit comments

Comments
 (0)