@@ -104,3 +104,89 @@ def ensemble_crps_inner(
104104 )
105105
106106 return ensemble_crps_inner
107+
108+
109+ def _get_brier_score_inputs (pivoted_dict : dict ,
110+ threshold : float ) -> dict :
111+ """Obtain inputs for scoringrules.brier_score from pivoted dict."""
112+ # get quantile flow
113+ p = pivoted_dict ['primary' ]
114+ q_threshold = np .quantile (p , threshold )
115+
116+ # get binary outcomes of observed exceeding threshold
117+ binary_p = np .where (p >= q_threshold , 1 , 0 )
118+
119+ # get fraction of ensemble members exceeding threshold for each time step
120+ s = pivoted_dict ['secondary' ]
121+ binary_s = np .where (s >= q_threshold , 1 , 0 )
122+ if len (binary_s .shape ) == 1 :
123+ # only one ensemble member
124+ frac_exceeds_s = binary_s
125+ else :
126+ frac_exceeds_s = np .mean (binary_s , axis = 1 )
127+
128+ # assemble inputs dict
129+ brier_score_inputs = {
130+ 'primary' : binary_p ,
131+ 'secondary' : frac_exceeds_s
132+ }
133+
134+ return brier_score_inputs
135+
136+
137+ def ensemble_brier_score (model : MetricsBasemodel ) -> Callable :
138+ """Create the Brier Score ensemble metric function."""
139+ logger .debug ("Building the Brier Score ensemble metric func." )
140+
141+ def ensemble_brier_score_inner (
142+ p : pd .Series ,
143+ s : pd .Series ,
144+ members : pd .Series ,
145+ ) -> float :
146+ """Create a wrapper around scoringrules brier_score.
147+
148+ Parameters
149+ ----------
150+ p : pd.Series
151+ The primary values.
152+ s : pd.Series
153+ The secondary values.
154+ members : pd.Series
155+ The member IDs.
156+ threshold : float
157+ The threshold for the Brier Score calculation.
158+
159+ Returns
160+ -------
161+ float
162+ The mean Brier Score for the ensemble, either as a single value
163+ or array of values.
164+ """
165+ # lazy load scoringrules
166+ import scoringrules as sr
167+
168+ # p, s, value_time = _transform(p, s, model, value_time)
169+ # pivoted_dict = _pivot_by_value_time(p, s, value_time)
170+ pivoted_dict = _pivot_by_member (p , s , members )
171+
172+ bs_inputs = _get_brier_score_inputs (
173+ pivoted_dict ,
174+ model .threshold
175+ )
176+
177+ if model .summary_func is not None :
178+ return model .summary_func (
179+ sr .brier_score (
180+ bs_inputs ["primary" ],
181+ bs_inputs ["secondary" ],
182+ backend = model .backend
183+ )
184+ )
185+ else :
186+ return sr .brier_score (
187+ bs_inputs ["primary" ],
188+ bs_inputs ["secondary" ],
189+ backend = model .backend
190+ )
191+
192+ return ensemble_brier_score_inner
0 commit comments