Skip to content

Commit 3d252c4

Browse files
Update API ProblemIterator (#18)
* update dependencies and allow for no POI update * add pre_solve_hook * set first tag * rm compat 1.6
1 parent 895d1a7 commit 3d252c4

File tree

7 files changed

+155
-38
lines changed

7 files changed

+155
-38
lines changed

.github/workflows/CI.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ jobs:
1818
fail-fast: false
1919
matrix:
2020
version:
21-
- '1.6'
2221
- '1.9'
2322
os:
2423
- ubuntu-latest

Project.toml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "L2O"
22
uuid = "e1d8bfa7-c465-446a-84b9-451470f6e76c"
33
authors = ["andrewrosemberg <[email protected]> and contributors"]
4-
version = "1.2.0-DEV"
4+
version = "1.0.0"
55

66
[deps]
77
Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45"
@@ -21,12 +21,12 @@ UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
2121
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2222

2323
[compat]
24-
Arrow = "2"
25-
CSV = "0.10"
26-
JuMP = "1"
27-
ParametricOptInterface = "0.7"
24+
Arrow = "^2"
25+
CSV = "^0.10"
26+
JuMP = "^1"
27+
ParametricOptInterface = "^0.8"
2828
Zygote = "^0.6.68"
29-
julia = "1.6"
29+
julia = "^1.9"
3030

3131
[extras]
3232
CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2"

src/FullyConnected.jl

Lines changed: 47 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -66,14 +66,20 @@ end
6666
# @forward((ConvexRegressor, :model), MLJFlux.Regressor)
6767

6868
# Define a container to hold any optimiser specific parameters (if any):
69-
struct ConvexRule <: Flux.Optimise.AbstractOptimiser
70-
rule::Flux.Optimise.AbstractOptimiser
69+
struct ConvexRule <: Optimisers.AbstractRule
70+
rule::Optimisers.AbstractRule
7171
tol::Real
7272
end
73-
function ConvexRule(rule::Flux.Optimise.AbstractOptimiser; tol=1e-6)
73+
function ConvexRule(rule::Optimisers.AbstractRule; tol=1e-6)
7474
return ConvexRule(rule, tol)
7575
end
7676

77+
Optimisers.init(o::ConvexRule, x::AbstractArray) = Optimisers.init(o.rule, x)
78+
79+
function Optimisers.apply!(o::ConvexRule, mvel, x::AbstractArray{T}, dx) where T
80+
return Optimisers.apply!(o.rule, mvel, x, dx)
81+
end
82+
7783
"""
7884
function make_convex!(chain::PairwiseFusion; tol = 1e-6)
7985
@@ -102,24 +108,48 @@ function make_convex!(model::Chain; tol=1e-6)
102108
end
103109
end
104110

105-
function MLJFlux.train!(
106-
model::MLJFlux.MLJFluxDeterministic, penalty, chain, optimiser::ConvexRule, X, y
107-
)
111+
function MLJFlux.train(
112+
model,
113+
chain,
114+
optimiser::ConvexRule,
115+
optimiser_state,
116+
epochs,
117+
verbosity,
118+
X,
119+
y,
120+
)
121+
108122
loss = model.loss
123+
124+
# intitialize and start progress meter:
125+
meter = MLJFlux.Progress(epochs + 1, dt=0, desc="Optimising neural net:",
126+
barglyphs=MLJFlux.BarGlyphs("[=> ]"), barlen=25, color=:yellow)
127+
verbosity != 1 || MLJFlux.next!(meter)
128+
129+
# initiate history:
109130
n_batches = length(y)
110-
training_loss = zero(Float32)
111-
for i in 1:n_batches
112-
parameters = Flux.params(chain)
113-
gs = Flux.gradient(parameters) do
114-
yhat = chain(X[i])
115-
batch_loss = loss(yhat, y[i]) + penalty(parameters) / n_batches
116-
training_loss += batch_loss
117-
return batch_loss
118-
end
119-
Flux.update!(optimiser.rule, parameters, gs)
131+
132+
losses = (loss(chain(X[i]), y[i]) for i in 1:n_batches)
133+
history = [mean(losses),]
134+
135+
for i in 1:epochs
136+
chain, optimiser_state, current_loss = MLJFlux.train_epoch(
137+
model,
138+
chain,
139+
optimiser,
140+
optimiser_state,
141+
X,
142+
y,
143+
)
120144
make_convex!(chain; tol=optimiser.tol)
145+
verbosity < 2 ||
146+
@info "Loss is $(round(current_loss; sigdigits=4))"
147+
verbosity != 1 || next!(meter)
148+
push!(history, current_loss)
121149
end
122-
return training_loss / n_batches
150+
151+
return chain, optimiser_state, history
152+
123153
end
124154

125155
function train!(model, loss, opt_state, X, Y; _batchsize=32, shuffle=true)

src/datasetgen.jl

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,14 @@ end
105105

106106
abstract type AbstractProblemIterator end
107107

108+
abstract type AbstractParameterType end
109+
110+
abstract type POIParamaterType <: AbstractParameterType end
111+
112+
abstract type JuMPNLPParameterType <: AbstractParameterType end
113+
114+
abstract type JuMPParameterType <: AbstractParameterType end
115+
108116
"""
109117
ProblemIterator(ids::Vector{UUID}, pairs::Dict{VariableRef, Vector{Real}})
110118
@@ -115,24 +123,30 @@ struct ProblemIterator{T<:Real} <: AbstractProblemIterator
115123
ids::Vector{UUID}
116124
pairs::Dict{VariableRef,Vector{T}}
117125
early_stop::Function
126+
param_type::Type{<:AbstractParameterType}
127+
pre_solve_hook::Function
118128
function ProblemIterator(
119129
ids::Vector{UUID},
120130
pairs::Dict{VariableRef,Vector{T}},
121131
early_stop::Function=(args...) -> false,
132+
param_type::Type{<:AbstractParameterType}=POIParamaterType,
133+
pre_solve_hook::Function=(args...) -> nothing
122134
) where {T<:Real}
123135
model = JuMP.owner_model(first(keys(pairs)))
124136
for (p, val) in pairs
125137
@assert length(ids) == length(val)
126138
end
127-
return new{T}(model, ids, pairs, early_stop)
139+
return new{T}(model, ids, pairs, early_stop, param_type, pre_solve_hook)
128140
end
129141
end
130142

131143
function ProblemIterator(
132-
pairs::Dict{VariableRef,Vector{T}}; early_stop::Function=(args...) -> false
133-
) where {T<:Real}
144+
pairs::Dict{VariableRef,Vector{T}}; early_stop::Function=(args...) -> false,
145+
pre_solve_hook::Function=(args...) -> nothing,
146+
param_type::Type{<:AbstractParameterType}=POIParamaterType,
134147
ids = [uuid1() for _ in 1:length(first(values(pairs)))]
135-
return ProblemIterator(ids, pairs, early_stop)
148+
) where {T<:Real}
149+
return ProblemIterator(ids, pairs, early_stop, param_type, pre_solve_hook)
136150
end
137151

138152
"""
@@ -174,7 +188,8 @@ end
174188

175189
function load(model_file::AbstractString, input_file::AbstractString, ::Type{T};
176190
batch_size::Union{Nothing, Integer}=nothing,
177-
ignore_ids::Vector{UUID}=UUID[]
191+
ignore_ids::Vector{UUID}=UUID[],
192+
param_type::Type{<:AbstractParameterType}=JuMPParameterType
178193
) where {T<:FileType}
179194
# Load full set
180195
df = load(input_file, T)
@@ -191,31 +206,40 @@ function load(model_file::AbstractString, input_file::AbstractString, ::Type{T};
191206
# No batch
192207
if isnothing(batch_size)
193208
pairs = _dataframe_to_dict(df, model_file)
194-
return ProblemIterator(ids, pairs)
209+
return ProblemIterator(pairs; ids=ids, param_type=param_type)
195210
end
196211
# Batch
197212
num_batches = ceil(Int, length(ids) / batch_size)
198213
idx_range = (i) -> (i-1)*batch_size+1:min(i*batch_size, length(ids))
199-
return (i) -> ProblemIterator(ids[idx_range(i)], _dataframe_to_dict(df[idx_range(i), :], model_file)), num_batches
214+
return (i) -> ProblemIterator(_dataframe_to_dict(df[idx_range(i), :], model_file);
215+
ids=ids[idx_range(i)], param_type=param_type), num_batches
200216
end
201217

202218
"""
203219
update_model!(model::JuMP.Model, p::VariableRef, val::Real)
204220
205221
Update the value of a parameter in a JuMP model.
206222
"""
207-
function update_model!(model::JuMP.Model, p::VariableRef, val)
223+
function update_model!(::Type{POIParamaterType}, model::JuMP.Model, p::VariableRef, val)
208224
return MOI.set(model, POI.ParameterValue(), p, val)
209225
end
210226

227+
function update_model!(::Type{JuMPNLPParameterType}, model::JuMP.Model, p::VariableRef, val)
228+
return set_parameter_value(p, val)
229+
end
230+
231+
function update_model!(::Type{JuMPParameterType}, model::JuMP.Model, p::VariableRef, val)
232+
return fix(p, val)
233+
end
234+
211235
"""
212236
update_model!(model::JuMP.Model, pairs::Dict, idx::Integer)
213237
214238
Update the values of parameters in a JuMP model.
215239
"""
216-
function update_model!(model::JuMP.Model, pairs::Dict, idx::Integer)
240+
function update_model!(model::JuMP.Model, pairs::Dict, idx::Integer, param_type::Type{<:AbstractParameterType})
217241
for (p, val) in pairs
218-
update_model!(model, p, val[idx])
242+
update_model!(param_type, model, p, val[idx])
219243
end
220244
end
221245

@@ -228,7 +252,8 @@ function solve_and_record(
228252
problem_iterator::ProblemIterator, recorder::Recorder, idx::Integer
229253
)
230254
model = problem_iterator.model
231-
update_model!(model, problem_iterator.pairs, idx)
255+
problem_iterator.pre_solve_hook(model)
256+
update_model!(model, problem_iterator.pairs, idx, problem_iterator.param_type)
232257
optimize!(model)
233258
status = recorder.filterfn(model)
234259
early_stop_bool = problem_iterator.early_stop(model, status, recorder)

src/samplers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ end
119119
Load the parameters from a JuMP model.
120120
"""
121121
function load_parameters(model::JuMP.Model)
122-
cons = constraint_object.(all_constraints(model, VariableRef, MOI.Parameter{Float64}))
122+
cons = constraint_object.([all_constraints(model, VariableRef, MOI.Parameter{Float64}); all_constraints(model, VariableRef, MOI.EqualTo{Float64})])
123123
parameters = [cons[i].func for i in 1:length(cons)]
124124
vals = [cons[i].set.value for i in 1:length(cons)]
125125
return parameters, vals

test/datasetgen.jl

Lines changed: 64 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
Test dataset generation for different filetypes
66
"""
77
function test_problem_iterator(path::AbstractString)
8-
@testset "Dataset Generation Type: $filetype" for filetype in [CSVFile, ArrowFile]
8+
@testset "Dataset Generation (POI) Type: $filetype" for filetype in [CSVFile, ArrowFile]
99
# The problem to iterate over
1010
model = JuMP.Model(() -> POI.Optimizer(HiGHS.Optimizer()))
1111
@variable(model, x)
@@ -55,6 +55,17 @@ function test_problem_iterator(path::AbstractString)
5555
@test num_p * successfull_solves == 1
5656
end
5757

58+
@testset "pre_solve_hook" begin
59+
file_dual_output = joinpath(path, "test_$(string(uuid1()))_output") # file path
60+
recorder_dual = Recorder{filetype}(file_dual_output; dual_variables=[cons])
61+
sum_p = 0
62+
problem_iterator = ProblemIterator(
63+
Dict(p => collect(1.0:num_p)); pre_solve_hook=(args...) -> sum_p += 1
64+
)
65+
successfull_solves = solve_batch(problem_iterator, recorder_dual)
66+
@test sum_p == num_p
67+
end
68+
5869
@testset "solve_batch" begin
5970
successfull_solves = solve_batch(problem_iterator, recorder)
6071

@@ -96,6 +107,58 @@ function test_problem_iterator(path::AbstractString)
96107
end
97108
end
98109
end
110+
@testset "Dataset Generation JuMP" begin
111+
model = JuMP.Model(HiGHS.Optimizer)
112+
@variable(model, x)
113+
p = @variable(model, _p)
114+
@constraint(model, cons, x + _p >= 3)
115+
@objective(model, Min, 2x)
116+
num_p = 10
117+
batch_id = string(uuid1())
118+
problem_iterator = ProblemIterator(Dict(p => collect(1.0:num_p)); param_type=L2O.JuMPParameterType)
119+
file_output = joinpath(path, "test_$(batch_id)_output") # file path
120+
recorder = Recorder{ArrowFile}(
121+
file_output; primal_variables=[x], dual_variables=[cons]
122+
)
123+
successfull_solves = solve_batch(problem_iterator, recorder)
124+
iter_files = readdir(joinpath(path))
125+
iter_files = filter(x -> occursin(string(ArrowFile), x), iter_files)
126+
file_outs = [
127+
joinpath(path, file) for
128+
file in iter_files if occursin("$(batch_id)_output", file)
129+
]
130+
# test output file
131+
df = Arrow.Table(file_outs)
132+
@test length(df) == 8
133+
@test length(df[1]) == num_p * successfull_solves
134+
rm.(file_outs)
135+
end
136+
@testset "Dataset Generation JuMPNLP" begin
137+
model = JuMP.Model(Ipopt.Optimizer)
138+
@variable(model, x)
139+
p = @variable(model, _p in MOI.Parameter(1.0))
140+
@constraint(model, cons, x + _p >= 3)
141+
@objective(model, Min, 2x)
142+
num_p = 10
143+
batch_id = string(uuid1())
144+
problem_iterator = ProblemIterator(Dict(p => collect(1.0:num_p)); param_type=L2O.JuMPNLPParameterType)
145+
file_output = joinpath(path, "test_$(batch_id)_output") # file path
146+
recorder = Recorder{ArrowFile}(
147+
file_output; primal_variables=[x], dual_variables=[cons]
148+
)
149+
successfull_solves = solve_batch(problem_iterator, recorder)
150+
iter_files = readdir(joinpath(path))
151+
iter_files = filter(x -> occursin(string(ArrowFile), x), iter_files)
152+
file_outs = [
153+
joinpath(path, file) for
154+
file in iter_files if occursin("$(batch_id)_output", file)
155+
]
156+
# test output file
157+
df = Arrow.Table(file_outs)
158+
@test length(df) == 8
159+
@test length(df[1]) == num_p * successfull_solves
160+
rm.(file_outs)
161+
end
99162
end
100163

101164
function test_load(model_file::AbstractString, input_file::AbstractString, ::Type{T}, ids::Vector{UUID};

test/test_flux_forecaster.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ function test_flux_forecaster(file_in::AbstractString, file_out::AbstractString)
2121
rng=123,
2222
epochs=20,
2323
optimiser=ConvexRule(
24-
Flux.Optimise.Adam(0.001, (0.9, 0.999), 1.0e-8, IdDict{Any,Any}())
24+
Optimisers.Adam()
2525
),
2626
)
2727

0 commit comments

Comments
 (0)