@@ -63,20 +63,29 @@ def test_multilabel_train_test_split_few_shot(dataset_unsplitted):
6363 assert dataset .get_n_classes (Split .TRAIN ) == dataset .get_n_classes (Split .TEST )
6464
6565
@pytest.mark.parametrize("allow_oos_in_train", [True, False])
def test_multiclass_train_test_split_few_shot(dataset_unsplitted, allow_oos_in_train):
    """Check the few-shot train/test split of a multiclass dataset.

    The test runs twice, once per value of ``allow_oos_in_train``, and verifies
    that ``split_dataset`` called with ``is_few_shot=True``:

    * produces both a TRAIN and a TEST split,
    * yields the expected number of rows in each split,
    * keeps the same number of classes in TRAIN and TEST,
    * puts at most ``examples_per_intent`` rows of any class id into TRAIN.

    Args:
        dataset_unsplitted: Fixture providing the dataset to split
            (defined outside this block).
        allow_oos_in_train: Forwarded unchanged to ``split_dataset``.
    """
    # Expected sizes depend on the flag; both variants add up to 36 rows.
    # NOTE(review): the 2 extra train rows are presumably the out-of-scope
    # samples kept in TRAIN -- confirm against the fixture data.
    train_num_rows = 10 if allow_oos_in_train else 8
    test_num_rows = 26 if allow_oos_in_train else 28
    examples_per_intent = 2

    dataset = dataset_unsplitted
    dataset[Split.TRAIN], dataset[Split.TEST] = split_dataset(
        dataset,
        split=Split.TRAIN,
        test_size=0.5,
        random_seed=42,
        allow_oos_in_train=allow_oos_in_train,
        is_few_shot=True,
        examples_per_intent=examples_per_intent,
    )

    assert Split.TRAIN in dataset
    assert Split.TEST in dataset
    assert dataset[Split.TRAIN].num_rows == train_num_rows
    assert dataset[Split.TEST].num_rows == test_num_rows
    assert dataset.get_n_classes(Split.TRAIN) == dataset.get_n_classes(Split.TEST)

    # Per-class cap: no class id may exceed the few-shot budget in TRAIN.
    for class_id in range(dataset.get_n_classes(Split.TRAIN)):
        # The lambda is handed to ``filter`` within the same iteration, so the
        # loop-variable capture flagged by B023 is deliberately silenced
        # (assumes ``filter`` evaluates eagerly -- TODO confirm).
        class_ds = dataset[Split.TRAIN].filter(lambda x: x["label"] == class_id)  # noqa: B023
        assert len(class_ds) <= examples_per_intent
0 commit comments