@@ -788,22 +788,43 @@ class RelativeEconomicValue(base.PerVariableMetric):
788788 """Relative economic value.
789789
790790 This metric assumes that the targets are binary and the predictions
791- are probabilities between in [0, 1]. It computes REV across all possible
792- decision thresholds for a given ensemble size.
791+ are probabilities in [0, 1].
792+
793+ Defaults to 50 cost-loss ratios spaced uniformly on a log scale from 0.005 up
794+ to (but excluding) 1, unless overridden by setting cost_loss_ratios directly.
795+
796+ By default, REV is computed across all possible decision thresholds for a
797+ given ensemble size. However, it is also possible to specify the desired
798+ probability thresholds. This is useful if you only want to compute REV for a
799+ subset of all possible probability thresholds.
800+
801+ TODO(tomandersson, matthjw): Add a helper method to ensure overridden
802+ probability thresholds are a subset of the between-each-ensemble-member-count
803+ thresholds, e.g. by snapping-to-nearest then deduplicating.
793804 """
794805
795806 def __init__ (
796- self , ensemble_size : int , cost_loss_ratios : Optional [np .ndarray ] = None
807+ self ,
808+ ensemble_size : Optional [int ] = None ,
809+ cost_loss_ratios : Optional [np .ndarray ] = None ,
810+ probability_thresholds : Optional [np .ndarray ] = None ,
811+ statistic_suffix : Optional [str ] = None ,
797812 ):
798813
799- thresholds = (np .arange (ensemble_size ) + 0.5 ) / ensemble_size
800-
801- self ._thresholds = xr .DataArray (
802- thresholds , dims = ['threshold' ], coords = {'threshold' : thresholds }
803- )
804- # TODO(tomandersson): Make this configurable when thresholds themselves are
805- # configurable.
806- self ._unique_name_suffix = 'all_thresholds_for_ensemble_size'
814+ if ensemble_size is None and probability_thresholds is None :
815+ raise ValueError (
816+ 'Either ensemble_size or probability_thresholds must be specified.'
817+ )
818+ if probability_thresholds is not None and ensemble_size is not None :
819+ raise ValueError (
820+ 'Only one of ensemble_size or probability_thresholds must be'
821+ ' specified.'
822+ )
823+ if probability_thresholds is not None and statistic_suffix is None :
824+ raise ValueError (
825+ 'If probability_thresholds is specified, statistic_suffix must be'
826+ ' specified.'
827+ )
807828
808829 if cost_loss_ratios is None :
809830 cost_loss_ratios = np .geomspace (0.005 , 1 , 51 )[:- 1 ]
@@ -814,6 +835,21 @@ def __init__(
814835 coords = {'cost_loss_ratio' : cost_loss_ratios },
815836 )
816837
838+ self ._thresholds = probability_thresholds
839+ if self ._thresholds is None :
840+ self ._thresholds = (np .arange (ensemble_size ) + 0.5 ) / ensemble_size
841+ if statistic_suffix is None :
842+ statistic_suffix = 'all_thresholds_for_ensemble_size'
843+ if not np .all (self ._thresholds >= 0.0 ) or not np .all (
844+ self ._thresholds <= 1.0
845+ ):
846+ raise ValueError (
847+ 'Probability thresholds must be in [0, 1], got'
848+ f' { self ._thresholds = } .'
849+ )
850+
851+ self ._unique_name_suffix = statistic_suffix or ''
852+
817853 @property
818854 def statistics (self ) -> Mapping [str , base .Statistic ]:
819855 binarize = wrappers .ContinuousToBinary (
0 commit comments