Skip to content

Commit 4d9fd03

Browse files
Tentative implementation of GRC ensembles
Granger-Ramanathan variant C which has been shown to outperform mean weighted ensembles
1 parent 1586564 commit 4d9fd03

File tree

3 files changed

+116
-10
lines changed

3 files changed

+116
-10
lines changed

src/Nodes/Ensembles/EnsembleNode.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,5 +6,15 @@ using Statistics
66
abstract type EnsembleNode <: NetworkNode end
77

88
include("WeightedEnsembleNode.jl")
9+
include("GREnsembleNode.jl")
910
# include("StateEnsembleNode.jl")
1011

12+
function prep_state!(node::EnsembleNode, timesteps::Int64)::Nothing
13+
node.outflow = fill(0.0, timesteps)
14+
15+
for n in node.instances
16+
prep_state!(n, timesteps)
17+
end
18+
19+
return nothing
20+
end
Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
Base.@kwdef mutable struct GREnsembleNode{N<:NetworkNode, P, A<:Real} <: EnsembleNode
2+
name::String
3+
area::A
4+
5+
instances::Array{N} = NetworkNode[]
6+
7+
# GRC method
8+
comb_method::Function = grc_combine
9+
10+
outflow::Array{A} = []
11+
12+
obj_func::Function = obj_func
13+
end
14+
15+
function GREnsembleNode(nodes::Vector{<:NetworkNode})::GREnsembleNode
16+
n1 = nodes[1]
17+
gr_node = GREnsembleNode{NetworkNode, Param, Float64}(;
18+
name=n1.name,
19+
area=n1.area,
20+
instances=nodes
21+
)
22+
23+
return gr_node
24+
end
25+
26+
function create_node(
27+
node::Type{<:GREnsembleNode},
28+
nodes::Vector{<:NetworkNode};
29+
kwargs...
30+
)
31+
return GREnsembleNode(nodes; kwargs...)
32+
end
33+
34+
function grc_weights(X::Matrix{T}, y::Vector{T}) where T<:Real
35+
# Add constant term for bias correction
36+
X_aug = hcat(ones(size(X,1)), X)
37+
38+
# Solve normal equations: β = (X'X)^(-1)X'y
39+
β = inv(X_aug' * X_aug) * X_aug' * y
40+
41+
# Split bias term and weights
42+
bias = β[1]
43+
weights = β[2:end]
44+
45+
return weights, bias
46+
end
47+
48+
function grc_combine(X::Matrix{T}, weights::Vector{T}, bias::T) where T<:Real
49+
# Apply weights and bias correction
50+
return X * weights .+ bias
51+
end
52+
53+
function calibrate_instances!(
54+
ensemble::GREnsembleNode,
55+
climate::Climate,
56+
calib_data::DataFrame,
57+
metric::Union{F,AbstractDict{String,F}};
58+
kwargs...
59+
) where {F}
60+
61+
# Calibrate individual instances first
62+
for node in ensemble.instances
63+
calibrate!(node, climate, calib_data, metric; kwargs...)
64+
end
65+
66+
# Then determine GRC
67+
return calibrate!(ensemble, climate, calib_data, metric; kwargs...)
68+
end
69+
70+
function calibrate!(
71+
ensemble::GREnsembleNode,
72+
climate::Climate,
73+
calib_data::DataFrame,
74+
metric::Union{C,AbstractDict{String,C}}; # Unused, added to maintain consistent interface
75+
kwargs...
76+
) where {C<:Function}
77+
78+
for inst in ensemble.instances
79+
run_node!(inst, climate)
80+
end
81+
82+
X = Matrix(hcat([m.outflow for m in ensemble.instances]...))
83+
weights, bias = grc_weights(X, calib_data[:, ensemble.name])
84+
85+
ensemble.comb_method = (X) -> grc_combine(hcat(X...), weights, bias)
86+
87+
return nothing
88+
end
89+
90+
function run_node!(ensemble::GREnsembleNode, climate::Climate; inflow=nothing, extraction=nothing, exchange=nothing)
91+
for inst in ensemble.instances
92+
run_node!(inst, climate; inflow=inflow, extraction=extraction, exchange=exchange)
93+
end
94+
95+
X = hcat([m.outflow for m in ensemble.instances]...)'
96+
97+
ensemble.outflow = ensemble.comb_method([inst.outflow for inst in ensemble.instances])
98+
end
99+
100+
function run_timestep!(node::WeightedEnsembleNode, rain, et, ts; inflow=0.0, extraction=0.0, exchange=0.0)
101+
for inst in node.instances
102+
run_timestep!(inst, rain, et, ts; inflow=inflow, extraction=extraction, exchange=exchange)
103+
end
104+
105+
node.outflow[ts] = node.comb_method([inst.outflow[ts] for inst in node.instances])
106+
end

src/Nodes/Ensembles/WeightedEnsembleNode.jl

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -55,16 +55,6 @@ function WeightedEnsembleNode(nodes::Vector{<:NetworkNode}; weights::Vector{Floa
5555
return tmp
5656
end
5757

58-
function prep_state!(node::WeightedEnsembleNode, timesteps::Int64)::Nothing
59-
node.outflow = fill(0.0, timesteps)
60-
61-
for n in node.instances
62-
prep_state!(n, timesteps)
63-
end
64-
65-
return nothing
66-
end
67-
6858

6959
function param_info(node::WeightedEnsembleNode; kwargs...)::Tuple
7060
values = Float64[w.val for w in node.weights]

0 commit comments

Comments
 (0)