
Commit 776c2a3

change Y input to dictionary with forecast variables as keys
1 parent de925de commit 776c2a3

File tree

8 files changed: +107 −44 lines


src/ApplicationDrivenLearning.jl

Lines changed: 1 addition & 1 deletion
@@ -251,7 +251,7 @@ Train model using given data and options.
 function train!(
     model::Model,
     X::Matrix{<:Real},
-    y::Matrix{<:Real},
+    y::Dict{<:Forecast, <:Vector},
     options::Options,
 )
     if options.mode == NelderMeadMode
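As a quick sketch of the new calling convention (the forecast-variable handle `d`, the data sizes, and `options` below are hypothetical, following the pattern used in the updated tests), the targets are now one `Vector` of observations per forecast variable instead of one `Matrix` column per variable:

X = Float32.(ones(10, 1))
Y = Dict(d => Float32.(rand(10)))   # one series per Forecast key
ApplicationDrivenLearning.train!(model, X, Y, options)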

src/optimizers/bilevel.jl

Lines changed: 11 additions & 14 deletions
@@ -5,7 +5,7 @@ using BilevelJuMP
 function solve_bilevel(
     model::Model,
     X::Matrix{<:Real},
-    Y::Matrix{<:Real},
+    Y::Dict{<:Forecast, <:Vector},
     params::Dict{Symbol,Any},
 )

@@ -23,7 +23,7 @@ function solve_bilevel(
     end

     # parameters
-    T = size(Y, 1)
+    T = size(X, 1)

     # lower model variables
     low_var_map =

@@ -107,22 +107,21 @@ function solve_bilevel(
     @objective(Upper(bilevel_model), post_obj_sense, up_obj)

     # fix upper model observations
-    i_obs_var = 1
-    for obs_var in assess_forecast_vars(model)
+    for obs_var in model.forecast_vars
         @constraint(
             Upper(bilevel_model),
-            up_var_map[obs_var] - Y[1:T, i_obs_var] .== 0
+            up_var_map[obs_var.assess] - Y[obs_var] .== 0
         )
-        i_obs_var += 1
     end

     # implement predictive model expression iterating through
     # models and layers to create predictive expression
     npreds = size(model.forecast.networks, 1)
     predictive_model_vars = [Dict{Int,Any}() for ipred = 1:npreds]
-    y_hat = Matrix{Any}(undef, size(Y, 1), size(Y, 2))
+    # y_hat = Matrix{Any}(undef, size(Y, 1), size(Y, 2))
+    y_hat = VariableIndexedMatrix{Any}(nothing, model.forecast_vars, T)
     for ipred = 1:npreds
-        layers_inpt = Dict{Any,Any}(
+        layers_inpt = Dict{Vector{Forecast},Matrix{Any}}(
             output_idx => X[1:T, input_idx] for (input_idx, output_idx) in
             model.forecast.input_output_map[ipred]
         )

@@ -158,19 +157,17 @@ function solve_bilevel(
             i_layer += 1
         end
         for (output_idx, prediction) in layers_inpt
-            y_hat[:, output_idx] = prediction
+            y_hat[output_idx] = prediction
         end
     end

     # and apply prediction on lower model as constraint
-    ipred_var_count = 1
-    for pred_var in plan_forecast_vars(model)
-        low_pred_var = low_var_map[pred_var]
+    for pred_var in model.forecast_vars
+        low_pred_var = low_var_map[pred_var.plan]
         @constraint(
             Lower(bilevel_model),
-            low_pred_var .- y_hat[:, ipred_var_count] .== 0
+            low_pred_var .- y_hat[pred_var] .== 0
         )
-        ipred_var_count += 1
     end

     # solve model

src/optimizers/gradient.jl

Lines changed: 3 additions & 2 deletions
@@ -5,7 +5,8 @@ Compute assess cost and cost gradient (with respect to predicted values) based
 on incomplete batch of examples.
 """
 function stochastic_compute(model, X, Y, batch, compute_full_cost::Bool)
-    C, dC = compute_cost(model, X[batch, :], Y[batch, :], true)
+    Y_batch = Dict(k => v[batch] for (k, v) in Y)
+    C, dC = compute_cost(model, X[batch, :], Y_batch, true)
     if compute_full_cost
         C = compute_cost(model, X, Y, false)
     end

@@ -24,7 +25,7 @@ end
 function train_with_gradient!(
     model::Model,
     X::Matrix{<:Real},
-    Y::Matrix{<:Real},
+    Y::Dict{<:Forecast, <:Vector},
     params::Dict{Symbol,Any},
 )
     # extract params
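A minimal illustration of the new batching step, with hypothetical forecast handles `d1`, `d2` and toy data: every series in the `Dict` is sliced with the same batch indices before being forwarded to `compute_cost`.

Y = Dict(d1 => [1.0, 2.0, 3.0, 4.0], d2 => [10.0, 20.0, 30.0, 40.0])
batch = [2, 4]
Y_batch = Dict(k => v[batch] for (k, v) in Y)
# Y_batch[d1] == [2.0, 4.0] and Y_batch[d2] == [20.0, 40.0]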

src/optimizers/nelder_mead.jl

Lines changed: 1 addition & 1 deletion
@@ -3,7 +3,7 @@ using Optim
 function train_with_nelder_mead!(
     model::Model,
     X::Matrix{<:Real},
-    Y::Matrix{<:Real},
+    Y::Dict{<:Forecast, <:Vector},
     params::Dict{Symbol,Any},
 )

src/simulation.jl

Lines changed: 28 additions & 18 deletions
@@ -1,10 +1,10 @@
 function compute_single_step_cost(
     model::Model,
-    y::Vector{<:Real},
+    y::VariableIndexedVector,
     yhat::VariableIndexedVector,
 )
     # set forecast params as prediction output
-    MOI.set.(model.plan, POI.ParameterValue(), model.plan_forecast_params, yhat[model.forecast_vars])
+    MOI.set.(model.plan, POI.ParameterValue(), model.plan_forecast_params, yhat[model.forecast_vars].data)
     # optimize plan model
     optimize!(model.plan)
     # check for solution and fix assess policy vars

@@ -18,7 +18,7 @@ function compute_single_step_cost(
         throw(e)
     end
     # fix assess forecast vars on observer values
-    fix.(assess_forecast_vars(model), y; force = true)
+    fix.(model.forecast_vars.assess, y[model.forecast_vars].data; force = true)
     # optimize assess model
     optimize!(model.assess)
     # check for optimization

@@ -38,7 +38,7 @@ Computes the gradient of the cost function (C) with respect to the predictions (
 function compute_single_step_gradient(
     model::Model,
     dCdz::Vector{<:Real},
-    dCdy::Vector{<:Real},
+    dCdy::VariableIndexedVector{<:Real},
 )
     dCdz .= dual.(model.assess[:assess_policy_fix])
     DiffOpt.empty_input_sensitivities!(model.plan)

@@ -51,12 +51,12 @@ function compute_single_step_gradient(
         )
     end
     DiffOpt.reverse_differentiate!(model.plan)
-    for j = 1:size(model.forecast_vars, 1)
-        dCdy[j] =
+    for fv in model.forecast_vars
+        dCdy[fv] =
             MOI.get(
                 model.plan,
                 DiffOpt.ReverseConstraintSet(),
-                ParameterRef(model.plan_forecast_params[j]),
+                ParameterRef(fv.plan),
             ).value
     end

@@ -81,26 +81,36 @@ Compute the cost function (C) based on the model predictions and the true values
 function compute_cost(
     model::Model,
     X::Matrix{<:Real},
-    Y::Matrix{<:Real},
+    Y::Dict{<:Forecast, <:Vector},
     with_gradients::Bool = false,
     aggregate::Bool = true,
 )

     # data size assertions
     @assert size(X)[2] == model.forecast.input_size "Input size mismatch"
-    @assert size(Y)[2] == model.forecast.output_size "Output size mismatch"
+    @assert length(Y) == model.forecast.output_size "Output size mismatch"

     # build model variables if necessary
     build(model)

     # init parameters
-    T = size(Y)[1]
-    C = zeros(T)
-    dC = zeros((T, model.forecast.output_size))
+    T = length.(collect(values(Y)))[1]
+    C = Vector{Float32}(undef, T)
     dCdz = Vector{Float32}(undef, size(model.policy_vars, 1))
-    dCdy = Vector{Float32}(undef, model.forecast.output_size)
+    dCdy = VariableIndexedVector{Float32}(undef, model.forecast_vars)
+    dC = VariableIndexedMatrix{Float32}(undef, model.forecast_vars, T)

-    function _compute_step(y::Vector{<:Real}, yhat::VariableIndexedVector)
+    function _get_index_y(Y::Dict{<:Forecast, <:Vector}, idx::Int)
+        var_index = Vector{Forecast}(undef, model.forecast.output_size)
+        y_values = Vector{Real}(undef, model.forecast.output_size)
+        for (i, (fvar, vals)) in enumerate(Y)
+            var_index[i] = fvar
+            y_values[i] = vals[idx]
+        end
+        return VariableIndexedVector(y_values, var_index)
+    end
+
+    function _compute_step(y::VariableIndexedVector, yhat::VariableIndexedVector)
         c = compute_single_step_cost(model, y, yhat)
         if with_gradients
             dc = compute_single_step_gradient(model, dCdz, dCdy)

@@ -114,15 +124,15 @@ function compute_cost(

     # main loop to compute cost
     for t = 1:T
-        result = _compute_step(Y[t, :], Yhat[t])
-        C[t] += result[1]
-        dC[t, :] .+= result[2]
+        result = _compute_step(_get_index_y(Y, t), Yhat[t])
+        C[t] = result[1]
+        dC[t] = result[2]
     end

     # aggregate cost if requested
     if aggregate
         C = sum(C) / T
-        dC = sum(dC, dims = 1)[1, :] / T
+        dC = sum(dC) / T
     end

     if with_gradients
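Conceptually, the inner `_get_index_y` helper replaces the old `Y[t, :]` row slice. For a hypothetical pair of forecast handles `d1`, `d2` with toy data

Y = Dict(d1 => [1.0, 2.0], d2 => [10.0, 20.0])

`_get_index_y(Y, 2)` builds a `VariableIndexedVector` pairing `d1` with 2.0 and `d2` with 20.0, i.e. the t = 2 observation of every forecast series, which is then passed to `_compute_step`.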

src/variable_indexed_structs.jl

Lines changed: 55 additions & 1 deletion
@@ -1,4 +1,5 @@
 import LinearAlgebra
+import Base./

 """
     VariableIndexedVector(data::Vector{T}, index::Vector{<:Forecast})

@@ -34,6 +35,9 @@ Base.length(v::VariableIndexedVector) = length(v.data)
 Base.getindex(v::VariableIndexedVector, i::Int) = v.data[i]
 Base.setindex!(v::VariableIndexedVector, val, i::Int) = (v.data[i] = val)

+# define Base right divide function (v / 2)
+/(v::VariableIndexedVector, i::Number) = VariableIndexedVector(v.data / i, v.index)
+
 # define dot product of two VariableIndexedVectors
 function LinearAlgebra.dot(v1::VariableIndexedVector, v2::VariableIndexedVector)
     @assert length(v1) == length(v2) "Vectors must have the same length"

@@ -98,6 +102,23 @@ struct VariableIndexedMatrix{T} <: AbstractMatrix{T}
         @assert length(unique(row_index)) == length(row_index) "Variables must be unique"
         new{T}(data, row_index)
     end
+
+    function VariableIndexedMatrix{T}(::UndefInitializer, index::Vector{<:Forecast}, n::Real) where T
+        return new{T}(Matrix{T}(undef, length(index), n), index)
+    end
+
+    function VariableIndexedMatrix{T}(::Nothing, index::Vector{<:Forecast}, n::Real) where T
+        return new{T}(Matrix{T}(nothing, length(index), n), index)
+    end
+end
+
+# helper to find index of a Forecast variable
+function _get_idx(m::VariableIndexedMatrix, var::Forecast)
+    i = findfirst(isequal(var), m.row_index)
+    if isnothing(i)
+        throw(KeyError(var))
+    end
+    return i
 end

 Base.size(m::VariableIndexedMatrix) = size(m.data)

@@ -106,11 +127,44 @@ Base.size(m::VariableIndexedMatrix) = size(m.data)
 Base.getindex(m::VariableIndexedMatrix, i::Int, j::Int) = m.data[i, j]
 Base.setindex!(m::VariableIndexedMatrix, val, i::Int, j::Int) = (m.data[i, j] = val)

-# column lookup (get column 2) {M[:, 2]}
+# column lookup (get column 2) {M[2]}
 function Base.getindex(m::VariableIndexedMatrix, c::Int)
     return VariableIndexedVector(m.data[:, c], m.row_index)
 end

+# row lookup (get values from variable) {M[forecast_var]}
+function Base.getindex(m::VariableIndexedMatrix, var::Forecast)
+    return m.data[_get_idx(m, var), :]
+end
+
+# multi-row lookup {M[[f_var_1, f_var_2]]}
+function Base.getindex(m::VariableIndexedMatrix, vars::Vector{<:Forecast})
+    return VariableIndexedMatrix(m.data[[_get_idx(m, var) for var in vars], :], vars)
+end
+
+# set column 2 values for all var indices {M[2] = [1,2,3]}
+function Base.setindex!(m::VariableIndexedMatrix, values::VariableIndexedVector, c::Int)
+    m.data[:, c] = values[m.row_index]
+end
+
+# set values for a single var index {M[forecast_var] = [1,2,3]}
+function Base.setindex!(m::VariableIndexedMatrix, values::Vector, var::Forecast)
+    m.data[_get_idx(m, var), :] = values
+end
+
+# set values for a subset of var indices {M[[f_var_1, f_var_2]] = [[1 2 3]; [4 5 6]]}
+function Base.setindex!(m::VariableIndexedMatrix, values::Matrix, vars::Vector{<:Forecast})
+    m.data[[_get_idx(m, var) for var in vars], :] = values
+end
+
+# define sum of matrix by summing all values for each variable
+function Base.sum(m::VariableIndexedMatrix)
+    return VariableIndexedVector(sum(m.data, dims=2)[:, 1], m.row_index)
+end
+
+# define Base right divide function (M / 2)
+/(m::ApplicationDrivenLearning.VariableIndexedMatrix, i::Number) = ApplicationDrivenLearning.VariableIndexedMatrix(m.data / i, m.row_index)
+
 # define dot product of a VariableIndexedVectors and a VariableIndexedMatrix
 function LinearAlgebra.dot(v1::VariableIndexedVector, m2::VariableIndexedMatrix)
     @assert length(v1) == size(m2, 1) "Vector must have the same length as the number of rows in matrix"
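To make the new indexing surface concrete, a small sketch with hypothetical `Forecast` handles `f1`, `f2` and a horizon of 3 (values are made up):

fvars = [f1, f2]
M = VariableIndexedMatrix{Float64}(undef, fvars, 3)
M[f1] = [1.0, 2.0, 3.0]   # row write keyed by forecast variable
M[f2] = [4.0, 5.0, 6.0]
col = M[2]                # column read -> VariableIndexedVector over fvars
row = M[f1]               # row read -> plain Vector [1.0, 2.0, 3.0]
avg = sum(M) / 3          # per-variable sums, then the new right-divide overload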

test/test_gradient.jl

Lines changed: 2 additions & 2 deletions
@@ -1,6 +1,4 @@
 # basic model for testing gradient mode
-X = Float32.(ones(1, 1))
-Y = Float32.(ones(1, 1))
 model = ApplicationDrivenLearning.Model()
 @variables(model, begin
     x >= 0, ApplicationDrivenLearning.Policy

@@ -11,6 +9,8 @@ end)
 set_optimizer(model, HiGHS.Optimizer)
 set_silent(model)
 ApplicationDrivenLearning.set_forecast_model(model, Chain(Dense(1 => 1)))
+X = Float32.(ones(1, 1))
+Y = Dict(d => Float32.(ones(1)))

 @testset "GradientMode Stop Rules" begin
     # epochs

test/test_newsvendor.jl

Lines changed: 6 additions & 5 deletions
@@ -1,10 +1,6 @@
 c = 5.0
 q = 9.0
 r = 4.0
-X = ones(1, 1)
-Y = 50 * ones(1, 1)
-best_decision = y = Y[1, 1]
-best_cost = (c - q) * y

 model = ApplicationDrivenLearning.Model()
 @variables(model, begin

@@ -44,6 +40,11 @@ set_optimizer(model, HiGHS.Optimizer)
 set_silent(model)
 nn = Chain(Dense(1 => 1; bias = false, init = (size...) -> rand(size...)))

+X = ones(1, 1)
+Y = Dict(d => [50.0])
+best_decision = y = Y[d][1]
+best_cost = (c - q) * y
+
 @testset "Newsvendor BilevelMode" begin
     ApplicationDrivenLearning.set_forecast_model(
         model,

@@ -87,7 +88,7 @@ end
 opt = ApplicationDrivenLearning.Options(
     ApplicationDrivenLearning.GradientMode;
     rule = Flux.Adam(1.0),
-    epochs = 150,
+    epochs = 200,
     batch_size = -1,
     verbose = false,
 )
