1717@pytest .mark .parametrize ("agg" , ALLOWED_AGGREGATIONS )
1818def test_inside_knn_ad (metric , agg ):
1919 if "binary" in metric :
20- X_train , X_test = get_data_inside_ad (binarize = True )
20+ X_train , X_test = get_data_inside_ad (n_train = 100 , n_test = 10 , binarize = True )
2121 else :
22- X_train , X_test = get_data_inside_ad ()
22+ X_train , X_test = get_data_inside_ad (n_train = 100 , n_test = 10 )
2323
2424 ad_checker = KNNADChecker (k = 3 , agg = agg )
2525 ad_checker .fit (X_train )
@@ -39,9 +39,9 @@ def test_inside_knn_ad(metric, agg):
3939@pytest .mark .parametrize ("agg" , ALLOWED_AGGREGATIONS )
4040def test_outside_knn_ad (metric , agg ):
4141 if "binary" in metric :
42- X_train , X_test = get_data_outside_ad (binarize = True )
42+ X_train , X_test = get_data_outside_ad (n_train = 100 , n_test = 10 , binarize = True )
4343 else :
44- X_train , X_test = get_data_outside_ad ()
44+ X_train , X_test = get_data_outside_ad (n_train = 100 , n_test = 10 )
4545
4646 ad_checker = KNNADChecker (k = 3 , metric = metric , agg = agg )
4747 ad_checker .fit (X_train )
@@ -60,9 +60,9 @@ def test_outside_knn_ad(metric, agg):
6060@pytest .mark .parametrize ("agg" , ALLOWED_AGGREGATIONS )
6161def test_knn_different_k_values (metric , agg ):
6262 if "binary" in metric :
63- X_train , X_test = get_data_inside_ad (binarize = True )
63+ X_train , X_test = get_data_inside_ad (n_train = 100 , n_test = 10 , binarize = True )
6464 else :
65- X_train , X_test = get_data_inside_ad ()
65+ X_train , X_test = get_data_inside_ad (n_train = 100 , n_test = 10 )
6666
6767 # smaller k, stricter check
6868 ad_checker_k1 = KNNADChecker (k = 1 , metric = metric , agg = agg )
@@ -84,9 +84,9 @@ def test_knn_different_k_values(metric, agg):
8484def test_knn_pass_y_train (metric , agg ):
8585 # smoke test, should not throw errors
8686 if "binary" in metric :
87- X_train , _ = get_data_inside_ad (binarize = True )
87+ X_train , _ = get_data_inside_ad (n_train = 100 , n_test = 10 , binarize = True )
8888 else :
89- X_train , _ = get_data_inside_ad ()
89+ X_train , _ = get_data_inside_ad (n_train = 100 , n_test = 10 )
9090
9191 y_train = np .zeros (len (X_train ))
9292 ad_checker = KNNADChecker (k = 3 , metric = metric , agg = agg )
@@ -97,9 +97,9 @@ def test_knn_pass_y_train(metric, agg):
9797@pytest .mark .parametrize ("agg" , ALLOWED_AGGREGATIONS )
9898def test_knn_invalid_k (metric , agg ):
9999 if "binary" in metric :
100- X_train , _ = get_data_inside_ad (binarize = True )
100+ X_train , _ = get_data_inside_ad (n_train = 100 , n_test = 10 , binarize = True )
101101 else :
102- X_train , _ = get_data_inside_ad ()
102+ X_train , _ = get_data_inside_ad (n_train = 100 , n_test = 10 )
103103
104104 with pytest .raises (
105105 ValueError ,
@@ -110,7 +110,7 @@ def test_knn_invalid_k(metric, agg):
110110
111111
112112def test_knn_invalid_metric ():
113- X_train , _ = get_data_inside_ad ()
113+ X_train , _ = get_data_inside_ad (n_train = 100 )
114114 ad_checker = KNNADChecker (k = 3 , metric = "euclidean" )
115115 with pytest .raises (KeyError ):
116116 ad_checker .fit (X_train )
0 commit comments