@@ -331,22 +331,9 @@ def test_leaderboard(
331331):
332332 # Comprehensive test takes a substantial amount of time, manually set if
333333 # required.
334- MAX_COMBO_SIZE_FOR_INCLUDE_PARAM = 2 # [0, len(valid_columns) + 1]
334+ MAX_COMBO_SIZE_FOR_INCLUDE_PARAM = 3 # [0, len(valid_columns) + 1]
335+ column_types = AutoSklearnEstimator ._leaderboard_columns ()
335336
336- X_train , Y_train , _ , _ = putil .get_dataset (dataset_name )
337- model = estimator_type (
338- time_left_for_this_task = 30 ,
339- per_run_time_limit = 5 ,
340- tmp_folder = tmp_dir ,
341- seed = 1
342- )
343- model .fit (X_train , Y_train )
344-
345- column_types = {
346- 'all' : AutoSklearnEstimator ._leaderboard_columns ['all' ],
347- 'simple' : AutoSklearnEstimator ._leaderboard_columns ['simple' ],
348- 'detailed' : AutoSklearnEstimator ._leaderboard_columns ['all' ]
349- }
350337 # Create a dict of all possible param values for each param
351338 # with some invalid ones of the incorrect type
352339 include_combinations = itertools .chain (
@@ -357,7 +344,7 @@ def test_leaderboard(
357344 'detailed' : [True , False ],
358345 'ensemble_only' : [True , False ],
359346 'top_k' : [- 10 , 0 , 1 , 10 , 'all' ],
360- 'sort_by' : [column_types ['all' ], 'invalid' ],
347+ 'sort_by' : [* column_types ['all' ], 'invalid' ],
361348 'sort_order' : ['ascending' , 'descending' , 'auto' , 'invalid' , None ],
362349 'include' : itertools .chain ([None , 'invalid' , 'type' ], include_combinations ),
363350 }
@@ -368,7 +355,19 @@ def test_leaderboard(
368355 for param_values in itertools .product (* valid_params .values ())
369356 )
370357
358+ X_train , Y_train , _ , _ = putil .get_dataset (dataset_name )
359+ model = estimator_type (
360+ time_left_for_this_task = 30 ,
361+ per_run_time_limit = 5 ,
362+ tmp_folder = tmp_dir ,
363+ seed = 1
364+ )
365+ model .fit (X_train , Y_train )
366+
371367 for params in params_generator :
368+ # Convert from iterator to solid list
369+ if params ['include' ] is not None and not isinstance (params ['include' ], str ):
370+ params ['include' ] = list (params ['include' ])
372371
373372 # Invalid top_k should raise an error; it must be a positive int or 'all'
374373 if not (params ['top_k' ] == 'all' or params ['top_k' ] > 0 ):
@@ -385,26 +384,32 @@ def test_leaderboard(
385384 with pytest .raises (ValueError ):
386385 model .leaderboard (** params )
387386
388- # Invalid include item in a list
389- elif params ['include' ] is not None :
390- # Crash if just a str but invalid column
391- if (
392- isinstance (params ['include' ], str )
393- and params ['include' ] not in column_types ['all' ]
394- ):
395- with pytest .raises (ValueError ):
396- model .leaderboard (** params )
397- # Crash if list but contains invalid column
398- elif (
399- not isinstance (params ['include' ], str )
400- and len (set (params ['include' ]) - set (column_types ['all' ])) != 0
401- ):
402- with pytest .raises (ValueError ):
403- model .leaderboard (** params )
387+ # include is single str but not valid
388+ elif (
389+ isinstance (params ['include' ], str )
390+ and params ['include' ] not in column_types ['all' ]
391+ ):
392+ with pytest .raises (ValueError ):
393+ model .leaderboard (** params )
394+
395+ # Crash if include is list but contains invalid column
396+ elif (
397+ isinstance (params ['include' ], list )
398+ and len (set (params ['include' ]) - set (column_types ['all' ])) != 0
399+ ):
400+ with pytest .raises (ValueError ):
401+ model .leaderboard (** params )
402+
403+ # Can't have just model_id, in both single str and list case
404+ elif (
405+ params ['include' ] == 'model_id'
406+ or params ['include' ] == ['model_id' ]
407+ ):
408+ with pytest .raises (ValueError ):
409+ model .leaderboard (** params )
404410
405- # Should run without an error if all params are valid
411+ # Else all valid combinations should be validated
406412 else :
407- # Validate the outputs
408413 leaderboard = model .leaderboard (** params )
409414
410415 # top_k should never be less than the rows given back
@@ -413,14 +418,23 @@ def test_leaderboard(
413418 assert params ['top_k' ] >= len (leaderboard )
414419
415420 # Check the right columns are present and in the right order
416- # The id is set as the index but is not included in pandas columns
421+ # The model_id is set as the index, not included in pandas columns
417422 columns = list (leaderboard .columns )
423+
424+ def exclude (lst , s ):
425+ return [x for x in lst if x != s ]
426+
418427 if params ['include' ] is not None :
419- assert columns == list (params ['include' ])
428+ # Include with only single str should be the only column
429+ if isinstance (params ['include' ], str ):
430+ assert params ['include' ] in columns and len (columns ) == 1
431+ # Include as a list should have all the columns without model_id
432+ else :
433+ assert columns == exclude (params ['include' ], 'model_id' )
420434 elif params ['detailed' ]:
421- assert columns == column_types ['detailed' ]
435+ assert columns == exclude ( column_types ['detailed' ], 'model_id' )
422436 else :
423- assert columns == column_types ['simple' ]
437+ assert columns == exclude ( column_types ['simple' ], 'model_id' )
424438
425439 # Ensure that if it's ensemble only
426440 # Can only check if 'ensemble_weight' is present
0 commit comments