Skip to content

Commit bbb1bc7

Browse files
author
WeatherBenchX authors
committed
No public description
PiperOrigin-RevId: 824614898
1 parent 16102ba commit bbb1bc7

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

weatherbenchX/metrics/categorical.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,7 @@ class Reliability(base.PerVariableMetric):
532532
predictions. It will automatically apply binning to the predictions into 10
533533
equal-width bins, assuming the predictions are in [0, 1]. You can modify the
534534
number of bins and the bin edges by passing in `bin_values` and `bin_dim`. For
535-
each bin of predicted probabilities, the metric will compute the probability
535+
each bin of predicted probabilities, the metric will compute the probability
536536
of the positive class according to the ground truth.
537537
"""
538538

@@ -552,16 +552,19 @@ def __init__(
552552
1.
553553
),
554554
bin_dim: str = 'reliability_bin',
555+
statistic_suffix: str | None = None,
555556
):
556557
self._bin_values = bin_values
557558
self._bin_dim = bin_dim
559+
self._unique_name_suffix = statistic_suffix
558560

559561
@property
560562
def statistics(self) -> Mapping[str, base.Statistic]:
561563
binned_prediction_wrapper = wrappers.ContinuousToBins(
562564
which='predictions',
563565
bin_values=self._bin_values,
564566
bin_dim=self._bin_dim,
567+
unique_name_suffix=self._unique_name_suffix,
565568
)
566569
return {
567570
'TruePositives': wrappers.WrappedStatistic(

weatherbenchX/metrics/wrappers_test.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,10 @@ class ContinuousToBinaryTest(parameterized.TestCase):
2727
def test_constant_threshold(self):
2828
target = test_utils.mock_target_data(random=True)
2929
ctb = wrappers.ContinuousToBinary(
30-
which='both', threshold_value=0.5, threshold_dim='threshold'
30+
which='both',
31+
threshold_value=0.5,
32+
threshold_dim='threshold',
33+
unique_name_suffix='test',
3134
)
3235

3336
x = target.geopotential

0 commit comments

Comments
 (0)