@@ -778,14 +778,15 @@ def sensitivity_benchmark(self, benchmarking_set, fit_args=None):
778778
779779 def aggregate (self , aggregation = "simple" ):
780780 if not isinstance (aggregation , str ):
781- raise TypeError (
782- "aggregation must be a string. " f"{ str (aggregation )} of type { type (aggregation )} was passed."
783- )
781+ raise TypeError ("aggregation must be a string. " f"{ str (aggregation )} of type { type (aggregation )} was passed." )
784782 valid_aggregations = ["simple" ]
785783 if aggregation not in valid_aggregations :
786- raise ValueError (
787- f"aggregation must be one of { valid_aggregations } . " f"{ str (aggregation )} was passed."
788- )
784+ raise ValueError (f"aggregation must be one of { valid_aggregations } . " f"{ str (aggregation )} was passed." )
785+ if self .framework is None :
786+ raise ValueError ("Apply fit() before aggregate()." )
787+
788+ if aggregation == "simple" :
789+ pass
789790 pass
790791
791792 def _fit_model (self , i_gt , n_jobs_cv = None , store_predictions = True , store_models = False , external_predictions_dict = None ):
@@ -888,9 +889,7 @@ def _rename_external_predictions(self, external_predictions):
888889 return ext_pred_dict
889890
890891 def _calc_nuisance_loss (self ):
891- nuisance_loss = {
892- learner : np .full ((self .n_rep , self .n_gt_atts ), np .nan ) for learner in self .modellist [0 ].params_names
893- }
892+ nuisance_loss = {learner : np .full ((self .n_rep , self .n_gt_atts ), np .nan ) for learner in self .modellist [0 ].params_names }
894893 for i_model , model in enumerate (self .modellist ):
895894 for learner in self .modellist [0 ].params_names :
896895 for i_rep in range (self .n_rep ):
0 commit comments