@@ -660,24 +660,45 @@ def test_fail_if_feat_type_on_pandas_input(backend, dask_client):
660660
661661
662662@pytest .mark .parametrize (
663- 'memory_limit,task' ,
663+ 'memory_limit,precision, task' ,
664664 [
665- (memory_limit , task )
665+ (memory_limit , precision , task )
666666 for task in itertools .chain (CLASSIFICATION_TASKS , REGRESSION_TASKS )
667- for memory_limit in (1 , 10 , None )
667+ for precision in (float , np .float32 , np .float64 , np .float128 )
668+ for memory_limit in (1 , 100 , None )
668669 ]
669670)
670- def test_subsample_if_too_large (memory_limit , task ):
671+ def test_subsample_if_too_large (memory_limit , precision , task ):
671672 fixture = {
672- BINARY_CLASSIFICATION : {1 : 436 , 10 : 569 , None : 569 },
673- MULTICLASS_CLASSIFICATION : {1 : 204 , 10 : 1797 , None : 1797 },
674- MULTILABEL_CLASSIFICATION : {1 : 204 , 10 : 1797 , None : 1797 },
675- REGRESSION : {1 : 1310 , 10 : 1326 , None : 1326 },
676- MULTIOUTPUT_REGRESSION : {1 : 1310 , 10 : 1326 , None : 1326 }
673+ BINARY_CLASSIFICATION : {
674+ 1 : {float : 1310 , np .float32 : 2621 , np .float64 : 1310 , np .float128 : 655 },
675+ 100 : {float : 12000 , np .float32 : 12000 , np .float64 : 12000 , np .float128 : 12000 },
676+ None : {float : 12000 , np .float32 : 12000 , np .float64 : 12000 , np .float128 : 12000 },
677+ },
678+ MULTICLASS_CLASSIFICATION : {
679+ 1 : {float : 204 , np .float32 : 409 , np .float64 : 204 , np .float128 : 102 },
680+ 100 : {float : 1797 , np .float32 : 1797 , np .float64 : 1797 , np .float128 : 1797 },
681+ None : {float : 1797 , np .float32 : 1797 , np .float64 : 1797 , np .float128 : 1797 },
682+ },
683+ MULTILABEL_CLASSIFICATION : {
684+ 1 : {float : 204 , np .float32 : 409 , np .float64 : 204 , np .float128 : 102 },
685+ 100 : {float : 1797 , np .float32 : 1797 , np .float64 : 1797 , np .float128 : 1797 },
686+ None : {float : 1797 , np .float32 : 1797 , np .float64 : 1797 , np .float128 : 1797 },
687+ },
688+ REGRESSION : {
689+ 1 : {float : 655 , np .float32 : 1310 , np .float64 : 655 , np .float128 : 327 },
690+ 100 : {float : 5000 , np .float32 : 5000 , np .float64 : 5000 , np .float128 : 5000 },
691+ None : {float : 5000 , np .float32 : 5000 , np .float64 : 5000 , np .float128 : 5000 },
692+ },
693+ MULTIOUTPUT_REGRESSION : {
694+ 1 : {float : 655 , np .float32 : 1310 , np .float64 : 655 , np .float128 : 327 },
695+ 100 : {float : 5000 , np .float32 : 5000 , np .float64 : 5000 , np .float128 : 5000 },
696+ None : {float : 5000 , np .float32 : 5000 , np .float64 : 5000 , np .float128 : 5000 },
697+ }
677698 }
678699 mock = unittest .mock .Mock ()
679700 if task == BINARY_CLASSIFICATION :
680- X , y = sklearn .datasets .load_breast_cancer ( return_X_y = True )
701+ X , y = sklearn .datasets .make_hastie_10_2 ( )
681702 elif task == MULTICLASS_CLASSIFICATION :
682703 X , y = sklearn .datasets .load_digits (return_X_y = True )
683704 elif task == MULTILABEL_CLASSIFICATION :
@@ -686,22 +707,22 @@ def test_subsample_if_too_large(memory_limit, task):
686707 for i , j in enumerate (y_ ):
687708 y [i , j ] = 1
688709 elif task == REGRESSION :
689- X , y = sklearn .datasets .load_diabetes (return_X_y = True )
690- X = np .vstack ((X , X , X ))
691- y = np .vstack ((y .reshape ((- 1 , 1 )), y .reshape ((- 1 , 1 )), y .reshape ((- 1 , 1 ))))
710+ X , y = sklearn .datasets .make_friedman1 (n_samples = 5000 , n_features = 20 )
692711 elif task == MULTIOUTPUT_REGRESSION :
693- X , y = sklearn .datasets .load_diabetes ( return_X_y = True )
712+ X , y = sklearn .datasets .make_friedman1 ( n_samples = 5000 , n_features = 20 )
694713 y = np .vstack ((y , y )).transpose ()
695- X = np .vstack ((X , X , X ))
696- y = np .vstack ((y , y , y ))
697714 else :
698715 raise ValueError (task )
716+ X = X .astype (precision )
699717
700718 assert X .shape [0 ] == y .shape [0 ]
701719
702720 X_new , y_new = AutoML .subsample_if_too_large (X , y , mock , 1 , memory_limit , task )
703- assert X_new .shape [0 ] == fixture [task ][memory_limit ]
721+ assert X_new .shape [0 ] == fixture [task ][memory_limit ][ precision ]
704722 if memory_limit == 1 :
705- assert mock .warning .call_count == 1
723+ if precision in (np .float128 , np .float64 , float ):
724+ assert mock .warning .call_count == 2
725+ else :
726+ assert mock .warning .call_count == 1
706727 else :
707728 assert mock .warning .call_count == 0
0 commit comments