Skip to content

Commit 5915b5f

Browse files
Update tests
1 parent 4841f44 commit 5915b5f

File tree

4 files changed

+56
-17
lines changed

4 files changed

+56
-17
lines changed

src/Nodes/Ensembles/WeightedEnsembleNode.jl

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1-
Base.@kwdef mutable struct WeightedEnsembleNode{N<:NetworkNode, P, A<:Real} <: EnsembleNode
1+
import ..Analysis: TemporalCrossSection
2+
3+
Base.@kwdef mutable struct WeightedEnsembleNode{N<:NetworkNode,P,A<:Real} <: EnsembleNode
24
name::String
35
area::A
46

57
instances::Array{N} = NetworkNode[]
68
weights::Array{P} = [
7-
Param(1.0; bounds=[0.0,1.0])
9+
Param(1.0; bounds=[0.0, 1.0])
810
]
911

1012
# Default to normalized weighted sum
@@ -25,8 +27,8 @@ function create_node(
2527
end
2628

2729
function WeightedEnsembleNode(nodes::Vector{<:NetworkNode}; weights::Vector{Float64},
28-
bounds::Union{Nothing, Vector{Tuple}}=nothing,
29-
comb_method::Union{Nothing, Function}=nothing)::WeightedEnsembleNode
30+
bounds::Union{Nothing,Vector{Tuple}}=nothing,
31+
comb_method::Union{Nothing,Function}=nothing)::WeightedEnsembleNode
3032
if isnothing(bounds)
3133
num_nodes = length(nodes)
3234
@assert(length(weights) == num_nodes, "Number of nodes do not match provided number of weights")
@@ -39,7 +41,7 @@ function WeightedEnsembleNode(nodes::Vector{<:NetworkNode}; weights::Vector{Floa
3941
p_weights = [Param(w; bounds=b) for (w, b) in zip(weights, bounds)]
4042

4143
n1 = nodes[1]
42-
tmp = WeightedEnsembleNode{NetworkNode, Param, Float64}(;
44+
tmp = WeightedEnsembleNode{NetworkNode,Param,Float64}(;
4345
name=n1.name,
4446
area=n1.area,
4547
instances=nodes,
@@ -157,12 +159,12 @@ function calibrate!(
157159
ensemble::WeightedEnsembleNode,
158160
climate::Climate,
159161
calib_data::DataFrame,
160-
metric::Union{AbstractDict{String, C}, C};
162+
metric::Union{AbstractDict{String,C},C};
161163
kwargs...
162164
) where {C<:Function}
163165
return invoke(
164166
calibrate!,
165-
Tuple{NetworkNode, Climate, DataFrame, typeof(metric)},
167+
Tuple{NetworkNode,Climate,DataFrame,typeof(metric)},
166168
ensemble,
167169
climate,
168170
calib_data,
@@ -208,7 +210,7 @@ function apply_bias_correction(
208210
period=monthday
209211
) where {T<:Real}
210212
dates = timesteps(climate)
211-
tcs = TemporalCrossSection(dates, obs, ensemble.outflow)
213+
tcs = TemporalCrossSection(dates, obs, ensemble.outflow; period)
212214

213215
return ensemble.outflow .+ (-tcs.ts)
214216
end

test/test_calibration.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ using Dates, DataFrames, CSV, YAML
33
using Streamfall
44

55

6-
DATA_PATH = joinpath(@__DIR__, "data/hymod")
6+
DATA_PATH = joinpath(dirname(dirname(pathof(Streamfall))), "test/data/hymod")
77

88
@testset "Single node calibration" begin
99
# Load and generate stream network
@@ -13,8 +13,8 @@ DATA_PATH = joinpath(@__DIR__, "data/hymod")
1313
# Load climate data
1414
date_format = "YYYY-mm-dd"
1515
obs_data = CSV.File(joinpath(DATA_PATH, "leaf_river_data.csv"),
16-
comment="#",
17-
dateformat=date_format) |> DataFrame
16+
comment="#",
17+
dateformat=date_format) |> DataFrame
1818

1919
# Column names must be the gauge name
2020
rename!(obs_data, ["leaf_river_outflow" => "leaf_river"])

test/test_data_op.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ using Streamfall
66

77
@testset "Data alignment" begin
88

9-
here = @__DIR__
9+
here = joinpath(dirname(dirname(pathof(Streamfall))), "test")
1010
climate_data = joinpath(here, "data/campaspe/climate/climate_historic.csv")
1111
dam_data_loc = joinpath(here, "data/campaspe/gauges")
1212

test/test_networks.jl

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ using DataFrames
77
using Streamfall
88

99

10-
DATA_PATH = joinpath(@__DIR__, "data/hymod")
10+
DATA_PATH = joinpath(dirname(dirname(pathof(Streamfall))), "test/data/hymod")
1111

1212

1313
@testset "Network loading" begin
@@ -71,7 +71,7 @@ end
7171

7272

7373
@testset "Reloading a network spec (HyMod)" begin
74-
spec::Dict{String, Union{Dict{String, Any}, Any}} = Streamfall.extract_network_spec(sn)
74+
spec::Dict{String,Union{Dict{String,Any},Any}} = Streamfall.extract_network_spec(sn)
7575
@test spec isa Dict
7676
@test haskey(spec, "leaf_river") || "Known node does not exist"
7777

@@ -97,7 +97,7 @@ end
9797
)
9898

9999
# Column name must match node name
100-
rename!(obs_data, ["leaf_river_outflow"=>"leaf_river"])
100+
rename!(obs_data, ["leaf_river_outflow" => "leaf_river"])
101101
climate_data = obs_data[:, ["Date", "leaf_river_P", "leaf_river_ET"]]
102102
climate = Climate(climate_data, "_P", "_ET")
103103

@@ -109,9 +109,46 @@ end
109109

110110
@testset "Recursing IHACRESNode upstream" begin
111111
begin
112-
include("../examples/run_nodes.jl")
113-
# Ensure example does not error out
112+
data_path = joinpath(dirname(dirname(pathof(Streamfall))), "test/data/campaspe/")
113+
114+
# Load and generate stream network
115+
sn = load_network("Example Network", joinpath(data_path, "campaspe_network.yml"))
116+
117+
climate = Climate("../test/data/campaspe/climate/climate.csv", "_rain", "_evap")
118+
119+
# Historic flows and dam level data
120+
calib_data = CSV.read(
121+
joinpath(data_path, "gauges", "outflow_and_level.csv"),
122+
DataFrame;
123+
comment="#"
124+
)
114125

126+
# Historic extractions from the dam
127+
extraction_data = CSV.read(
128+
joinpath(data_path, "gauges", "dam_extraction.csv"),
129+
DataFrame;
130+
comment="#"
131+
)
132+
133+
@info "Running example stream..."
134+
135+
dam_id, dam_node = sn["406000"]
136+
Streamfall.run_node!(sn, dam_id, climate; extraction=extraction_data)
137+
138+
# Extract data for comparison with 1-year burn-in period
139+
dam_obs = calib_data[:, "406000"][366:end]
140+
dam_sim = dam_node.level[366:end]
141+
142+
nnse_score = Streamfall.NNSE(dam_obs, dam_sim)
143+
nse_score = Streamfall.NSE(dam_obs, dam_sim)
144+
rmse_score = Streamfall.RMSE(dam_obs, dam_sim)
145+
146+
@info "Obj Func Scores:" rmse_score nnse_score nse_score
147+
148+
nse = round(nse_score, digits=4)
149+
rmse = round(rmse_score, digits=4)
150+
151+
# Ensure example does not error out
115152
reset!(sn)
116153
run_basin!(sn, climate; extraction=extraction_data)
117154
@test Streamfall.RMSE(dam_obs, sn[dam_id].level[366:end]) < 2.5

0 commit comments

Comments
 (0)