Skip to content

Commit 62bc90d

Browse files
committed
initial changes from JuliaFormatter
1 parent e9bf122 commit 62bc90d

File tree

12 files changed

+173
-217
lines changed

12 files changed

+173
-217
lines changed

docs/make.jl

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,10 @@ push!(LOAD_PATH, "../src")
44
using ApplicationDrivenLearning
55

66
makedocs(;
7-
modules=[ApplicationDrivenLearning],
8-
doctest=false,
9-
clean=true,
10-
sitename="ApplicationDrivenLearning.jl",
11-
authors="Giovanni Amorim, Joaquim Garcia",
12-
pages=[
13-
"Home" => "index.md",
14-
"API Reference" => "reference.md"
15-
]
16-
)
7+
modules = [ApplicationDrivenLearning],
8+
doctest = false,
9+
clean = true,
10+
sitename = "ApplicationDrivenLearning.jl",
11+
authors = "Giovanni Amorim, Joaquim Garcia",
12+
pages = ["Home" => "index.md", "API Reference" => "reference.md"],
13+
)

src/ApplicationDrivenLearning.jl

Lines changed: 28 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ struct Policy{T}
2222
end
2323

2424
+(p1::Policy, p2::Policy) = Policy(p1.plan + p2.plan, p1.assess + p2.assess)
25-
*(c::Number, p::Policy) = Policy(c*p.plan, c*p.assess)
25+
*(c::Number, p::Policy) = Policy(c * p.plan, c * p.assess)
2626

2727
"""
2828
Forecast{T}
@@ -35,7 +35,7 @@ struct Forecast{T}
3535
end
3636

3737
+(p1::Forecast, p2::Forecast) = Forecast(p1.plan + p2.plan, p1.assess + p2.assess)
38-
*(c::Number, p::Forecast) = Forecast(c*p.plan, c*p.assess)
38+
*(c::Number, p::Forecast) = Forecast(c * p.plan, c * p.assess)
3939

4040
"""
4141
Model <: JuMP.AbstractModel
@@ -46,7 +46,7 @@ missing forecast model and default settings.
4646
mutable struct Model <: JuMP.AbstractModel
4747
plan::JuMP.Model
4848
assess::JuMP.Model
49-
forecast::Union{PredictiveModel, Nothing}
49+
forecast::Union{PredictiveModel,Nothing}
5050

5151
# variable arrays
5252
policy_vars::Vector{Policy}
@@ -62,9 +62,14 @@ mutable struct Model <: JuMP.AbstractModel
6262
assess = JuMP.Model()
6363

6464
return new(
65-
plan, assess, nothing,
66-
Vector{Policy}(), Vector{Forecast}(), Vector{JuMP.VariableRef}(),
67-
Dict{Symbol,Any}(), false
65+
plan,
66+
assess,
67+
nothing,
68+
Vector{Policy}(),
69+
Vector{Forecast}(),
70+
Vector{JuMP.VariableRef}(),
71+
Dict{Symbol,Any}(),
72+
false,
6873
)
6974
end
7075
end
@@ -102,8 +107,8 @@ Sets Chain, Dense or custom PredictiveModel object as
102107
forecast model.
103108
"""
104109
function set_forecast_model(
105-
model::Model,
106-
network::Union{PredictiveModel, Flux.Chain, Flux.Dense}
110+
model::Model,
111+
network::Union{PredictiveModel,Flux.Chain,Flux.Dense},
107112
)
108113
if typeof(network) == PredictiveModel
109114
forecast = network
@@ -131,13 +136,13 @@ function build_plan_model_forecast_params(model::Model)
131136
# adds parametrized forecast variables using MOI.Parameter
132137
forecast_size = size(model.forecast_vars)[1]
133138
model.plan_forecast_params = @variable(
134-
model.plan,
139+
model.plan,
135140
_forecast[1:forecast_size] in MOI.Parameter.(zeros(forecast_size))
136141
)
137142
# fixes old and new prediction variables together
138143
@constraint(
139-
model.plan,
140-
plan_forecast_fix,
144+
model.plan,
145+
plan_forecast_fix,
141146
model.plan_forecast_params .== plan_forecast_vars(model)
142147
)
143148
end
@@ -146,11 +151,7 @@ end
146151
Creates new constraint to assess model that fixes policy variables.
147152
"""
148153
function build_assess_model_policy_constraint(model::Model)
149-
@constraint(
150-
model.assess,
151-
assess_policy_fix,
152-
assess_policy_vars(model) .== 0
153-
)
154+
@constraint(model.assess, assess_policy_fix, assess_policy_vars(model) .== 0)
154155
end
155156

156157
"""
@@ -183,12 +184,7 @@ include("optimizers/bilevel.jl")
183184
184185
Train model using given data and options.
185186
"""
186-
function train!(
187-
model::Model,
188-
X::Matrix{<:Real},
189-
y::Matrix{<:Real},
190-
options::Options,
191-
)
187+
function train!(model::Model, X::Matrix{<:Real}, y::Matrix{<:Real}, options::Options)
192188
if options.mode == NelderMeadMode
193189
return train_with_nelder_mead!(model, X, y, options.params)
194190
elseif options.mode == GradientMode
@@ -207,8 +203,14 @@ function train!(
207203
end
208204
end
209205

210-
export Model, PredictiveModel, Plan, Assess,
211-
Policy, Forecast,
212-
set_forecast_model, forecast,
213-
compute_cost, train!
206+
export Model,
207+
PredictiveModel,
208+
Plan,
209+
Assess,
210+
Policy,
211+
Forecast,
212+
set_forecast_model,
213+
forecast,
214+
compute_cost,
215+
train!
214216
end

src/flux_utils.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using Flux
55
66
Extract the parameters of a Flux model (Flux.Chain or Flux.Dense) into a single vector.
77
"""
8-
function extract_flux_params(model::Union{Flux.Chain, Flux.Dense})
8+
function extract_flux_params(model::Union{Flux.Chain,Flux.Dense})
99
θ = Flux.params(model)
1010
return reduce(vcat, [vec(p) for p in θ])
1111
end
@@ -15,7 +15,10 @@ end
1515
1616
Return model after fixing the parameters from an adequate vector of parameters.
1717
"""
18-
function fix_flux_params_single_model(model::Union{Flux.Chain, Flux.Dense}, θ::Vector{<:Real})
18+
function fix_flux_params_single_model(
19+
model::Union{Flux.Chain,Flux.Dense},
20+
θ::Vector{<:Real},
21+
)
1922
i = 1
2023
for p in Flux.params(model)
2124
psize = prod(size(p))
@@ -30,10 +33,7 @@ end
3033
3134
Return iterable of models after fixing the parameters from an adequate vector of parameters.
3235
"""
33-
function fix_flux_params_multi_model(
34-
models,
35-
θ::Vector{<:Real}
36-
)
36+
function fix_flux_params_multi_model(models, θ::Vector{<:Real})
3737
i = 1
3838
for model in models
3939
for p in Flux.params(model)

src/jump.jl

Lines changed: 21 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -11,29 +11,21 @@ function JuMP.build_variable(
1111
::Type{Policy};
1212
kwargs...,
1313
)
14-
return PolicyInfo(
15-
info,
16-
info,
17-
kwargs
18-
)
14+
return PolicyInfo(info, info, kwargs)
1915
end
2016

21-
function JuMP.add_variable(
22-
model::Model,
23-
policy_info::PolicyInfo,
24-
name::String
25-
)
17+
function JuMP.add_variable(model::Model, policy_info::PolicyInfo, name::String)
2618
policy = Policy(
2719
JuMP.add_variable(
28-
model.plan,
29-
JuMP.ScalarVariable(policy_info.plan),
30-
name * "_plan"
20+
model.plan,
21+
JuMP.ScalarVariable(policy_info.plan),
22+
name * "_plan",
3123
),
3224
JuMP.add_variable(
33-
model.assess,
34-
JuMP.ScalarVariable(policy_info.assess),
35-
name * "_assess"
36-
)
25+
model.assess,
26+
JuMP.ScalarVariable(policy_info.assess),
27+
name * "_assess",
28+
),
3729
)
3830
push!(model.policy_vars, policy)
3931
return policy
@@ -52,29 +44,21 @@ function JuMP.build_variable(
5244
::Type{Forecast};
5345
kwargs...,
5446
)
55-
return ForecastInfo(
56-
info,
57-
info,
58-
kwargs
59-
)
47+
return ForecastInfo(info, info, kwargs)
6048
end
6149

62-
function JuMP.add_variable(
63-
model::Model,
64-
forecast_info::ForecastInfo,
65-
name::String
66-
)
50+
function JuMP.add_variable(model::Model, forecast_info::ForecastInfo, name::String)
6751
forecast = Forecast(
6852
JuMP.add_variable(
69-
model.plan,
70-
JuMP.ScalarVariable(forecast_info.plan),
71-
name * "_plan"
53+
model.plan,
54+
JuMP.ScalarVariable(forecast_info.plan),
55+
name * "_plan",
7256
),
7357
JuMP.add_variable(
74-
model.assess,
75-
JuMP.ScalarVariable(forecast_info.assess),
76-
name * "_assess"
77-
)
58+
model.assess,
59+
JuMP.ScalarVariable(forecast_info.assess),
60+
name * "_assess",
61+
),
7862
)
7963
push!(model.forecast_vars, forecast)
8064
return forecast
@@ -124,12 +108,12 @@ end
124108

125109
JuMP.object_dictionary(model::Model) = model.obj_dict
126110

127-
function JuMP.set_optimizer(model::Model, builder, evaluate_duals::Bool=true)
111+
function JuMP.set_optimizer(model::Model, builder, evaluate_duals::Bool = true)
128112
# set diffopt optimizer for plan model
129113
new_diff_optimizer = DiffOpt.diff_optimizer(builder)
130114
JuMP.set_optimizer(
131115
model.plan,
132-
() -> POI.Optimizer(new_diff_optimizer; evaluate_duals=evaluate_duals)
116+
() -> POI.Optimizer(new_diff_optimizer; evaluate_duals = evaluate_duals),
133117
)
134118

135119
# basic setting for assess model
@@ -160,4 +144,4 @@ function Base.print(io::IO, model::Model)
160144
else
161145
println(model.forecast.network)
162146
end
163-
end
147+
end

0 commit comments

Comments
 (0)