Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion weatherbenchX/aggregation.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ def aggregation_fn(
# Can't bin based on dims that aren't present as evaluation unit dims:
return None

return xr.dot(stat, *weights, *bin_masks, dims=reduce_dims_set)
return xr.dot(stat, *weights, *bin_masks, dim=reduce_dims_set)

def aggregate_statistics(
self,
Expand Down
2 changes: 1 addition & 1 deletion weatherbenchX/beam_pipeline.py
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We will update the xarray requirement to >=2025.07, which should make this change unnecessary; unfortunately, this change won't work for us internally.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The update has now been merged, so the tests should run even without this change.

Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def process(self, metrics: xr.Dataset) -> None:
"""
logging.info('WriteMetrics inputs: %s', metrics)
with fsspec.open(self.out_path, 'wb', auto_mkdir=True) as f:
f.write(metrics.to_netcdf())
metrics.to_netcdf(f, engine="netcdf4")
return None


Expand Down
169 changes: 168 additions & 1 deletion weatherbenchX/metrics/categorical.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ def _compute_seeps_per_variable(
scoring_matrix = scoring_matrix.compute()

# Take dot product
result = xr.dot(out, scoring_matrix, dims=('forecast_cat', 'truth_cat'))
result = xr.dot(out, scoring_matrix, dim=('forecast_cat', 'truth_cat'))

# Mask out p1 thresholds
mask = (p1 >= min_p1) & (p1 <= max_p1)
Expand Down Expand Up @@ -518,3 +518,170 @@ def _values_from_mean_statistics_per_variable(
return statistic_values['TruePositives'] / (
statistic_values['TruePositives'] + statistic_values['FalsePositives']
)


class FalseAlarmRate(base.PerVariableMetric):
  """False alarm rate (F), also known as probability of false detection.

  Computed from the contingency-table entries as

    F = FP / (FP + TN)

  Reference:
    Donaldson, R. J., Dyer, R. M., & Kraus, M. J. (1975). An objective
    evaluator of techniques for predicting severe weather events. In
    Preprints, Ninth Conf. on Severe Local Storms, Norman, OK, Amer. Meteor.
    Soc. (pp. 321-326).
  """

  @property
  def statistics(self) -> Mapping[Hashable, base.Statistic]:
    """Statistics this metric needs, keyed by the name used to look them up."""
    required = {
        'FalsePositives': FalsePositives(),
        'TrueNegatives': TrueNegatives(),
    }
    return required

  def _values_from_mean_statistics_per_variable(
      self,
      statistic_values: Mapping[Hashable, xr.DataArray],
  ) -> xr.DataArray:
    """Computes the false alarm rate from aggregated statistics.

    Args:
      statistic_values: Aggregated statistic values, keyed by the names
        declared in `statistics`.

    Returns:
      FP / (FP + TN).
    """
    false_pos = statistic_values['FalsePositives']
    true_neg = statistic_values['TrueNegatives']
    # FP + TN is the denominator of the documented ratio.
    negatives = false_pos + true_neg
    return false_pos / negatives


class TrueNegativeRate(base.PerVariableMetric):
  """True negative rate (TNR), also known as specificity.

  Computed from the contingency-table entries as

    TNR = TN / (FP + TN)
  """

  @property
  def statistics(self) -> Mapping[Hashable, base.Statistic]:
    """Statistics this metric needs, keyed by the name used to look them up."""
    required = {
        'FalsePositives': FalsePositives(),
        'TrueNegatives': TrueNegatives(),
    }
    return required

  def _values_from_mean_statistics_per_variable(
      self,
      statistic_values: Mapping[Hashable, xr.DataArray],
  ) -> xr.DataArray:
    """Computes the true negative rate from aggregated statistics.

    Args:
      statistic_values: Aggregated statistic values, keyed by the names
        declared in `statistics`.

    Returns:
      TN / (FP + TN).
    """
    true_neg = statistic_values['TrueNegatives']
    false_pos = statistic_values['FalsePositives']
    # FP + TN is the denominator of the documented ratio.
    negatives = false_pos + true_neg
    return true_neg / negatives


class PeirceSkillScore(base.PerVariableMetric):
  """Peirce's skill score (PSS).

  Also called the Hanssen and Kuipers discriminant. It is the difference
  between the hit rate and the false alarm rate:

    H = TP / (TP + FN)
    F = FP / (FP + TN)

    PSS = H - F = TP / (TP + FN) - FP / (FP + TN)

  Reference:
    Peirce, C. S. (1884). The Numerical Measure of the Success of Predictions.
    Science, ns-4(93), 453–454. https://doi.org/10.1126/science.ns-4.93.453.b
  """

  @property
  def statistics(self) -> Mapping[Hashable, base.Statistic]:
    """Statistics this metric needs: all four contingency-table entries."""
    required = {
        'TruePositives': TruePositives(),
        'FalsePositives': FalsePositives(),
        'FalseNegatives': FalseNegatives(),
        'TrueNegatives': TrueNegatives(),
    }
    return required

  def _values_from_mean_statistics_per_variable(
      self,
      statistic_values: Mapping[Hashable, xr.DataArray],
  ) -> xr.DataArray:
    """Computes the Peirce skill score from aggregated statistics.

    Args:
      statistic_values: Aggregated statistic values, keyed by the names
        declared in `statistics`.

    Returns:
      TP / (TP + FN) - FP / (FP + TN).
    """
    true_pos = statistic_values['TruePositives']
    false_neg = statistic_values['FalseNegatives']
    false_pos = statistic_values['FalsePositives']
    true_neg = statistic_values['TrueNegatives']

    hit_rate = true_pos / (true_pos + false_neg)
    false_alarm_rate = false_pos / (false_pos + true_neg)
    return hit_rate - false_alarm_rate


class OddsRatioSkillScore(base.PerVariableMetric):
  """Odds ratio skill score (ORSS), also known as Yule's Q.

  Computed from the contingency-table entries as

    ORSS = (TP * TN - FP * FN) / (TP * TN + FP * FN)

  Reference:
    Stephenson, D. B. (2000). Use of the "Odds Ratio" for Diagnosing Forecast
    Skill. Retrieved from
    https://journals.ametsoc.org/view/journals/wefo/15/2/1520-0434_2000_015_0221_uotorf_2_0_co_2.xml
  """

  @property
  def statistics(self) -> Mapping[Hashable, base.Statistic]:
    """Statistics this metric needs: all four contingency-table entries."""
    required = {
        'TruePositives': TruePositives(),
        'FalsePositives': FalsePositives(),
        'FalseNegatives': FalseNegatives(),
        'TrueNegatives': TrueNegatives(),
    }
    return required

  def _values_from_mean_statistics_per_variable(
      self,
      statistic_values: Mapping[Hashable, xr.DataArray],
  ) -> xr.DataArray:
    """Computes the odds ratio skill score from aggregated statistics.

    Args:
      statistic_values: Aggregated statistic values, keyed by the names
        declared in `statistics`.

    Returns:
      (TP * TN - FP * FN) / (TP * TN + FP * FN).
    """
    # Product of the two "correct" cells of the contingency table.
    correct_product = (
        statistic_values['TruePositives'] * statistic_values['TrueNegatives']
    )
    # Product of the two "incorrect" cells.
    error_product = (
        statistic_values['FalsePositives'] * statistic_values['FalseNegatives']
    )
    return (correct_product - error_product) / (correct_product + error_product)


class ExtremalDependenceIndex(base.PerVariableMetric):
  """Extremal dependence index (EDI).

  Defined in terms of the hit rate H and the false alarm rate F:

    H = TP / (TP + FN)
    F = FP / (FP + TN)

    EDI = (lnF - lnH) / (lnF + lnH)

  Reference:
    Ferro, C. A. T., & Stephenson, D. B. (2011). Extremal Dependence Indices:
    Improved Verification Measures for Deterministic Forecasts of Rare Binary
    Events. Weather and Forecasting, 26(5), 699–713.
    https://doi.org/10.1175/WAF-D-10-05030.1
  """

  @property
  def statistics(self) -> Mapping[Hashable, base.Statistic]:
    """Statistics this metric needs: all four contingency-table entries."""
    required = {
        'TruePositives': TruePositives(),
        'FalsePositives': FalsePositives(),
        'FalseNegatives': FalseNegatives(),
        'TrueNegatives': TrueNegatives(),
    }
    return required

  def _values_from_mean_statistics_per_variable(
      self,
      statistic_values: Mapping[Hashable, xr.DataArray],
  ) -> xr.DataArray:
    """Computes the extremal dependence index from aggregated statistics.

    Args:
      statistic_values: Aggregated statistic values, keyed by the names
        declared in `statistics`.

    Returns:
      (ln F - ln H) / (ln F + ln H), with H and F as defined in the class
      docstring.
    """
    true_pos = statistic_values['TruePositives']
    false_neg = statistic_values['FalseNegatives']
    false_pos = statistic_values['FalsePositives']
    true_neg = statistic_values['TrueNegatives']

    hit_rate = true_pos / (true_pos + false_neg)
    false_alarm_rate = false_pos / (false_pos + true_neg)

    # NOTE(review): the logarithm is applied to the raw rates with no
    # clipping, so a rate of exactly 0 gives -inf here and the resulting
    # index is not finite -- confirm whether callers are expected to mask it.
    log_false_alarm_rate = np.log(false_alarm_rate)
    log_hit_rate = np.log(hit_rate)

    difference = log_false_alarm_rate - log_hit_rate
    total = log_false_alarm_rate + log_hit_rate
    return difference / total
Loading
Loading