@@ -90,45 +90,12 @@ def compare(
9090 leaderboard_kwargs : dict | None = None ,
9191 remove_imputed : bool = False ,
9292):
93- df_results = df_results .copy ()
94-
95- if isinstance (only_valid_tasks , str ):
96- only_valid_tasks = [only_valid_tasks ]
97- if isinstance (only_valid_tasks , list ):
98- for filter_method in only_valid_tasks :
99- # Filter to tasks present in a specific method
100- df_filter = df_results [df_results ["method" ] == filter_method ]
101- if "imputed" in df_filter .columns :
102- df_filter = df_filter [df_filter ["imputed" ] != True ]
103- assert len (df_filter ) != 0 , \
104- (f"No method named '{ filter_method } ' remains to filter to!\n "
105- f"Available tasks: { list (df_results ['method' ].unique ())} " )
106- df_results = filter_to_valid_tasks (
107- df_to_filter = df_results ,
108- df_filter = df_filter ,
109- )
110-
111- if "method_type" not in df_results .columns :
112- df_results ["method_type" ] = "baseline"
113- if "method_subtype" not in df_results .columns :
114- df_results ["method_subtype" ] = np .nan
115- if "config_type" not in df_results .columns :
116- df_results ["config_type" ] = np .nan
117- if "imputed" not in df_results .columns :
118- df_results ["imputed" ] = False
119-
120- if isinstance (fillna , str ):
121- fillna = df_results [df_results ["method" ] == fillna ]
122- if fillna is not None :
123- df_results = TabArenaContext .fillna_metrics (
124- df_to_fill = df_results ,
125- df_fillna = fillna ,
126- )
127-
128- if remove_imputed :
129- methods_imputed = df_results .groupby ("method" )["imputed" ].sum ()
130- methods_imputed = list (methods_imputed [methods_imputed > 0 ].index )
131- df_results = df_results [~ df_results ["method" ].isin (methods_imputed )]
93+ df_results = prepare_data (
94+ df_results = df_results ,
95+ only_valid_tasks = only_valid_tasks ,
96+ fillna = fillna ,
97+ remove_imputed = remove_imputed ,
98+ )
13299
133100 if score_on_val :
134101 error_col = "metric_error_val"
@@ -173,6 +140,55 @@ def is_in(dataset: str, fold: int) -> bool:
173140 return df_filtered
174141
175142
143+ def prepare_data (
144+ df_results : pd .DataFrame ,
145+ only_valid_tasks : str | list [str ] | None = None ,
146+ fillna : str | pd .DataFrame | None = None ,
147+ remove_imputed : bool = False ,
148+ ) -> pd .DataFrame :
149+ df_results = df_results .copy ()
150+
151+ if isinstance (only_valid_tasks , str ):
152+ only_valid_tasks = [only_valid_tasks ]
153+ if isinstance (only_valid_tasks , list ):
154+ for filter_method in only_valid_tasks :
155+ # Filter to tasks present in a specific method
156+ df_filter = df_results [df_results ["method" ] == filter_method ]
157+ if "imputed" in df_filter .columns :
158+ df_filter = df_filter [df_filter ["imputed" ] != True ]
159+ assert len (df_filter ) != 0 , \
160+ (f"No method named '{ filter_method } ' remains to filter to!\n "
161+ f"Available tasks: { list (df_results ['method' ].unique ())} " )
162+ df_results = filter_to_valid_tasks (
163+ df_to_filter = df_results ,
164+ df_filter = df_filter ,
165+ )
166+
167+ if "method_type" not in df_results .columns :
168+ df_results ["method_type" ] = "baseline"
169+ if "method_subtype" not in df_results .columns :
170+ df_results ["method_subtype" ] = np .nan
171+ if "config_type" not in df_results .columns :
172+ df_results ["config_type" ] = np .nan
173+ if "imputed" not in df_results .columns :
174+ df_results ["imputed" ] = False
175+
176+ if isinstance (fillna , str ):
177+ fillna = df_results [df_results ["method" ] == fillna ]
178+ if fillna is not None :
179+ df_results = TabArenaContext .fillna_metrics (
180+ df_to_fill = df_results ,
181+ df_fillna = fillna ,
182+ )
183+
184+ if remove_imputed :
185+ methods_imputed = df_results .groupby ("method" )["imputed" ].sum ()
186+ methods_imputed = list (methods_imputed [methods_imputed > 0 ].index )
187+ df_results = df_results [~ df_results ["method" ].isin (methods_imputed )]
188+
189+ return df_results
190+
191+
176192def subset_tasks (df_results : pd .DataFrame , subset : list [str ], folds : list [int ] = None ) -> pd .DataFrame :
177193 from tabarena .nips2025_utils .fetch_metadata import load_task_metadata
178194
0 commit comments