Skip to content

Commit 9c52f18

Browse files
committed
add aggregate arg on compute_cost funcion for allowing detailed cost returns
1 parent ea7455e commit 9c52f18

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ApplicationDrivenLearning"
22
uuid = "0856f1c8-ef17-4e14-9230-2773e47a789e"
33
authors = ["Giovanni Amorim", "Joaquim Garcia"]
4-
version = "0.1.1"
4+
version = "0.1.2"
55

66
[deps]
77
BilevelJuMP = "485130c0-026e-11ea-0f1a-6992cd14145c"

src/simulation.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ function compute_cost(
8383
X::Matrix{<:Real},
8484
Y::Matrix{<:Real},
8585
with_gradients::Bool = false,
86+
aggregate::Bool = true,
8687
)
8788

8889
# data size assertions
@@ -93,9 +94,9 @@ function compute_cost(
9394
build(model)
9495

9596
# init parameters
96-
C = 0
9797
T = size(Y)[1]
98-
dC = zeros(model.forecast.output_size)
98+
C = zeros(T)
99+
dC = zeros((T, model.forecast.output_size))
99100
dCdz = Vector{Float32}(undef, size(model.policy_vars, 1))
100101
dCdy = Vector{Float32}(undef, model.forecast.output_size)
101102

@@ -114,8 +115,14 @@ function compute_cost(
114115
# main loop to compute cost
115116
for t = 1:T
116117
result = _compute_step(Y[t, :], Yhat[t, :])
117-
C += result[1] ./ T
118-
dC .+= result[2] ./ T
118+
C[t] += result[1]
119+
dC[t, :] .+= result[2]
120+
end
121+
122+
# aggregate cost if requested
123+
if aggregate
124+
C = sum(C) / T
125+
dC = sum(dC, dims = 1)[1, :] / T
119126
end
120127

121128
if with_gradients

0 commit comments

Comments
 (0)