Skip to content

Commit a5afc24

Browse files
authored
Implement Brier Score/Brier Skill Score (#592)
* adds brier_score to probabilistic metrics, updates routine to calculate metric skill score * adds new ensemble test data, adds attribute to BS_ATTRS * updates test
1 parent c095b03 commit a5afc24

File tree

11 files changed

+284
-117
lines changed

11 files changed

+284
-117
lines changed

src/teehr/evaluation/metrics.py

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,11 +239,34 @@ def _post_process_metric_results(
239239
"""
240240
for model in include_metrics:
241241
if model.reference_configuration is not None:
242+
"""
242243
self.df = self._calculate_metric_skill_score(
243244
model.output_field_name,
244245
model.reference_configuration,
245246
group_by
246247
)
248+
"""
249+
# 1) get the original cols ahead of skill score join
250+
original_cols = self.df.columns
251+
# 2) calculate skill score sdf
252+
sdf = self._calculate_metric_skill_score(
253+
model.output_field_name,
254+
model.reference_configuration,
255+
group_by
256+
)
257+
# 3) remove original metric column from skill score sdf
258+
sdf = sdf.drop(model.output_field_name)
259+
# 3) get join columns
260+
join_cols = parse_fields_to_list(group_by)
261+
# 4) join returned table back to self.df, trim
262+
self.df = self.df.join(
263+
sdf,
264+
on=join_cols,
265+
how="left"
266+
).select(
267+
*original_cols,
268+
F.col(f"{model.output_field_name}_skill_score")
269+
)
247270

248271
if model.unpack_results:
249272
self.df = model.unpack_function(
@@ -292,11 +315,20 @@ def _calculate_metric_skill_score(
292315
temp_col = f"{config}_{metric_field}_skill"
293316
pivot_sdf = pivot_sdf.withColumn(
294317
temp_col,
295-
1 - F.col(config) / F.col(reference_configuration)
318+
1 - F.try_divide(F.col(config), F.col(reference_configuration))
296319
).withColumn(
297320
"configuration_name",
298321
F.lit(config)
299322
)
323+
# warn user if try_divide results in nulls (division by zero)
324+
null_count = pivot_sdf.filter(F.col(temp_col).isNull()).count()
325+
if null_count > 0:
326+
logger.warning(
327+
f"Division by zero encountered when calculating skill "
328+
f"score for configuration '{config}' relative to "
329+
f"reference configuration '{reference_configuration}'. "
330+
f"{null_count} null values were produced."
331+
)
300332
# Join skill score values from the pivot table.
301333
join_cols = group_by_strings + ["configuration_name"]
302334
sdf = sdf.join(

src/teehr/metrics/probabilistic_funcs.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,3 +104,89 @@ def ensemble_crps_inner(
104104
)
105105

106106
return ensemble_crps_inner
107+
108+
109+
def _get_brier_score_inputs(pivoted_dict: dict,
110+
threshold: float) -> dict:
111+
"""Obtain inputs for scoringrules.brier_score from pivoted dict."""
112+
# get quantile flow
113+
p = pivoted_dict['primary']
114+
q_threshold = np.quantile(p, threshold)
115+
116+
# get binary outcomes of observed exceeding threshold
117+
binary_p = np.where(p >= q_threshold, 1, 0)
118+
119+
# get fraction of ensemble members exceeding threshold for each time step
120+
s = pivoted_dict['secondary']
121+
binary_s = np.where(s >= q_threshold, 1, 0)
122+
if len(binary_s.shape) == 1:
123+
# only one ensemble member
124+
frac_exceeds_s = binary_s
125+
else:
126+
frac_exceeds_s = np.mean(binary_s, axis=1)
127+
128+
# assemble inputs dict
129+
brier_score_inputs = {
130+
'primary': binary_p,
131+
'secondary': frac_exceeds_s
132+
}
133+
134+
return brier_score_inputs
135+
136+
137+
def ensemble_brier_score(model: MetricsBasemodel) -> Callable:
138+
"""Create the Brier Score ensemble metric function."""
139+
logger.debug("Building the Brier Score ensemble metric func.")
140+
141+
def ensemble_brier_score_inner(
142+
p: pd.Series,
143+
s: pd.Series,
144+
members: pd.Series,
145+
) -> float:
146+
"""Create a wrapper around scoringrules brier_score.
147+
148+
Parameters
149+
----------
150+
p : pd.Series
151+
The primary values.
152+
s : pd.Series
153+
The secondary values.
154+
members : pd.Series
155+
The member IDs.
156+
threshold : float
157+
The threshold for the Brier Score calculation.
158+
159+
Returns
160+
-------
161+
float
162+
The mean Brier Score for the ensemble, either as a single value
163+
or array of values.
164+
"""
165+
# lazy load scoringrules
166+
import scoringrules as sr
167+
168+
# p, s, value_time = _transform(p, s, model, value_time)
169+
# pivoted_dict = _pivot_by_value_time(p, s, value_time)
170+
pivoted_dict = _pivot_by_member(p, s, members)
171+
172+
bs_inputs = _get_brier_score_inputs(
173+
pivoted_dict,
174+
model.threshold
175+
)
176+
177+
if model.summary_func is not None:
178+
return model.summary_func(
179+
sr.brier_score(
180+
bs_inputs["primary"],
181+
bs_inputs["secondary"],
182+
backend=model.backend
183+
)
184+
)
185+
else:
186+
return sr.brier_score(
187+
bs_inputs["primary"],
188+
bs_inputs["secondary"],
189+
backend=model.backend
190+
)
191+
192+
return ensemble_brier_score_inner

src/teehr/models/metrics/metric_attributes.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,15 @@
255255
"requires_threshold_field": False,
256256
}
257257

258+
BS_ENSEMBLE_ATTRS = {
259+
"short_name": "brier_score_ensemble",
260+
"display_name": "Brier Score - Ensemble",
261+
"category": mc.Probabilistic,
262+
"value_range": [0.0, 1.0],
263+
"optimal_value": 0.0,
264+
"requires_threshold_field": False,
265+
}
266+
258267
FDC_SLOPE_ATTRS = {
259268
"short_name": "fdc_slope",
260269
"display_name": "Flow Duration Curve Slope",

src/teehr/models/metrics/probabilistic_models.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,41 @@ class CRPS(ProbabilisticBasemodel):
4646
attrs: Dict = Field(default=tma.CRPS_ENSEMBLE_ATTRS, frozen=True)
4747

4848

49+
class BrierScore(ProbabilisticBasemodel):
50+
"""Brier Score for ensemble probabilistic forecasts.
51+
52+
Parameters
53+
----------
54+
threshold : float
55+
The threshold to use for binary event definition.
56+
backend : str
57+
The backend to use, by default "numba". Can be ("numba" or "numpy").
58+
summary_func : Callable
59+
The function to apply to the results, by default np.mean.
60+
output_field_name : str
61+
The output field name, by default "mean_brier_score".
62+
func : Callable
63+
The function to apply to the data, by default
64+
:func:`probabilistic_funcs.ensemble_brier_score`.
65+
input_field_names : Union[str, StrEnum, List[Union[str, StrEnum]]]
66+
The input field names, by default
67+
["primary_value", "secondary_value", "member"].
68+
attrs : Dict
69+
The static attributes for the metric.
70+
"""
71+
72+
threshold: float = Field(default=0.75)
73+
transform: TransformEnum = Field(default=None)
74+
backend: str = Field(default="numba")
75+
output_field_name: str = Field(default="mean_brier_score")
76+
func: Callable = Field(probabilistic_funcs.ensemble_brier_score, frozen=True)
77+
summary_func: Union[Callable, None] = Field(default=None)
78+
input_field_names: Union[str, StrEnum, List[Union[str, StrEnum]]] = Field(
79+
default=["primary_value", "secondary_value", "member"]
80+
)
81+
attrs: Dict = Field(default=tma.BS_ENSEMBLE_ATTRS, frozen=True)
82+
83+
4984
class ProbabilisticMetrics:
5085
"""Define and customize probalistic metrics.
5186
@@ -59,3 +94,4 @@ class ProbabilisticMetrics:
5994
"""
6095

6196
CRPS = CRPS
97+
BrierScore = BrierScore
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import teehr
2+
from pathlib import Path
3+
import pandas as pd
4+
from shapely.geometry import Point
5+
import geopandas as gpd
6+
7+
TEST_STUDY_DATA_DIR_V0_5 = Path("tests", "data", "v0_5_ensemble_study")
8+
9+
10+
def setup_v0_5_ensemble_study(tmpdir):
11+
"""Create a test evaluation with ensemble forecasts using teehr."""
12+
# define pathing
13+
location_xwalk_path = TEST_STUDY_DATA_DIR_V0_5 / "location_crosswalks.parquet"
14+
primary_ts_path = TEST_STUDY_DATA_DIR_V0_5 / "primary_timeseries.parquet"
15+
secondary_ts_path = TEST_STUDY_DATA_DIR_V0_5 / "secondary_timeseries.parquet"
16+
configurations_path = TEST_STUDY_DATA_DIR_V0_5 / "configurations.parquet"
17+
variables_path = TEST_STUDY_DATA_DIR_V0_5 / "variables.parquet"
18+
19+
# initialize evaluation
20+
ev = teehr.Evaluation(dir_path=tmpdir)
21+
ev.enable_logging()
22+
ev.clone_template()
23+
24+
# create locations
25+
location_dict = {
26+
'id': 'obs-GILN6',
27+
'name': 'Schoharie Creek at Schoharie',
28+
'geometry': Point(-74.45043182373047, 42.397300720214844),
29+
}
30+
gdf = gpd.GeoDataFrame(
31+
[location_dict],
32+
geometry='geometry',
33+
crs="EPSG:4269"
34+
)
35+
ev.locations.load_dataframe(df=gdf, write_mode="overwrite")
36+
37+
# load crosswalk
38+
ev.location_crosswalks.load_parquet(
39+
in_path=location_xwalk_path
40+
)
41+
42+
# add configurations
43+
df = pd.read_parquet(configurations_path)
44+
for _, row in df.iterrows():
45+
ev.configurations.add(
46+
teehr.Configuration(
47+
name=row["name"],
48+
type=row["type"],
49+
description=row["description"]
50+
)
51+
)
52+
53+
# add variables table
54+
df = pd.read_parquet(variables_path)
55+
for _, row in df.iterrows():
56+
ev.variables.add(
57+
teehr.Variable(
58+
name=row["name"],
59+
long_name=row["long_name"],
60+
)
61+
)
62+
63+
# load primary timeseries
64+
ev.primary_timeseries.load_parquet(
65+
in_path=primary_ts_path
66+
)
67+
68+
# load secondary timeseries
69+
ev.secondary_timeseries.load_parquet(
70+
in_path=secondary_ts_path
71+
)
72+
73+
# create JTS
74+
ev.joined_timeseries.create(add_attrs=False, execute_scripts=True)
75+
76+
return ev
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:494922481bbc395a13192449163db43d31ea7eae1b3fb519bb6370b816fb5fc6
3+
size 3252
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:c29d9bf222a29f0b61f08c290350b0692a9a13e4e98364073cc5a21dec412977
3+
size 2598
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:be10143a8baef8e794384d0ff4fe735afd0b89249fff66bc28a9a24edf2c745d
3+
size 663557
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:165184dd98de07d5be1b039a116329d1091f3bc9efb5c9d1124a58ca4a01ab84
3+
size 1665132
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
version https://git-lfs.github.com/spec/v1
2+
oid sha256:5e007ba5f8f4f11a9a392a0ed1b9f1b39ad48accd8ced199b8ac3129ce43ba95
3+
size 2574

0 commit comments

Comments
 (0)