@@ -19,6 +19,9 @@ def __init__(self, approximators: dict[str, Approximator], **kwargs):
1919 self .num_approximators = len (self .approximators )
2020
2121 def build (self , data_shapes : dict [str , tuple [int ] | dict [str , dict ]]) -> None :
22+ # Remove the ensemble dimension from data_shapes. This expects data_shapes are the shapes of a
23+ # batch of training data, where the second axis corresponds to different approximators.
24+ data_shapes = {k : v [:1 ] + v [2 :] for k , v in data_shapes .items ()}
2225 for approximator in self .approximators .values ():
2326 approximator .build (data_shapes )
2427
@@ -30,19 +33,32 @@ def compute_metrics(
3033 sample_weight : Tensor = None ,
3134 stage : str = "training" ,
3235 ) -> dict [str , dict [str , Tensor ]]:
36+ # Prepare empty dict for metrics
3337 metrics = {}
34- for approx_name , approximator in self .approximators .items ():
35- # TODO: actually do the slicing
36- inference_variables_slice = inference_variables
37- inference_conditions_slice = inference_conditions
38- summary_variables_slice = summary_variables
39- sample_weight_slice = sample_weight
38+
39+ # Define the variable slices as None (default) or respective input
40+ _inference_variables = inference_variables
41+ _inference_conditions = inference_conditions
42+ _summary_variables = summary_variables
43+ _sample_weight = sample_weight
44+
45+ for i , (approx_name , approximator ) in enumerate (self .approximators .items ()):
46+ # During training each approximator receives its own separate slice
47+ if stage == "training" :
48+ # Pick out the correct slice for each ensemble member
49+ _inference_variables = inference_variables [:, i ]
50+ if inference_conditions is not None :
51+ _inference_conditions = inference_conditions [:, i ]
52+ if summary_variables is not None :
53+ _summary_variables = summary_variables [:, i ]
54+ if sample_weight is not None :
55+ _sample_weight = sample_weight [:, i ]
4056
4157 metrics [approx_name ] = approximator .compute_metrics (
42- inference_variables = inference_variables_slice ,
43- inference_conditions = inference_conditions_slice ,
44- summary_variables = summary_variables_slice ,
45- sample_weight = sample_weight_slice ,
58+ inference_variables = _inference_variables ,
59+ inference_conditions = _inference_conditions ,
60+ summary_variables = _summary_variables ,
61+ sample_weight = _sample_weight ,
4662 stage = stage ,
4763 )
4864
@@ -51,7 +67,6 @@ def compute_metrics(
5167 for approx_name in metrics .keys ():
5268 for metric_key , value in metrics [approx_name ].items ():
5369 joint_metrics [f"{ approx_name } /{ metric_key } " ] = value
54-
5570 metrics = joint_metrics
5671
5772 # Sum over losses
0 commit comments