@@ -23,36 +23,57 @@ def multilabel_dataset():
2323 return Dataset .from_dict (data )
2424
2525
26- def test_gcn_scorer_fit_predict (multilabel_dataset ):
27- scorer = GCNScorer (
28- embedder_config = "prajjwal1/bert-tiny" ,
29- num_train_epochs = 1 ,
30- batch_size = 2 ,
31- )
@pytest.fixture
def multiclass_dataset():
    """Build a small single-label dataset: four training utterances over three intents."""
    # (utterance, label) pairs; every label here is a single integer intent id.
    samples = [
        ("utterance 1", 0),
        ("utterance 2", 1),
        ("utterance 3", 2),
        ("utterance 4", 0),
    ]
    intent_names = ["intent_0", "intent_1", "intent_2"]

    train_records = [{"utterance": text, "label": label} for text, label in samples]
    # Intent ids are the positions 0..2 of the names above.
    intent_records = [{"id": idx, "name": name} for idx, name in enumerate(intent_names)]
    return Dataset.from_dict({"train": train_records, "intents": intent_records})
42+
43+
def test_gcn_scorer_multilabel(multilabel_dataset):
    """Fit a GCNScorer on the multilabel fixture and check the shape and range of its predictions."""
    model = GCNScorer(
        embedder_config="prajjwal1/bert-tiny",
        num_train_epochs=1,
        batch_size=2,
    )
    train_split = multilabel_dataset["train"]
    model.fit(train_split["utterance"], train_split["label"])

    scores = model.predict(["test 1", "test 2"])

    assert isinstance(scores, np.ndarray)
    # One row per test utterance; the 3 columns presumably match the number of
    # intents in the multilabel fixture (defined outside this view) - TODO confirm.
    assert scores.shape == (2, 3)
    # Every score must lie in the closed interval [0, 1].
    assert ((scores >= 0) & (scores <= 1)).all()
55+
56+
def test_gcn_scorer_multiclass(multiclass_dataset):
    """Fit a GCNScorer on the multiclass fixture and check that predictions are bounded and rows sum to ~1."""
    model = GCNScorer(
        embedder_config="prajjwal1/bert-tiny",
        num_train_epochs=1,
        batch_size=2,
    )
    train_split = multiclass_dataset["train"]
    model.fit(train_split["utterance"], train_split["label"])

    scores = model.predict(["test 1", "test 2"])

    assert isinstance(scores, np.ndarray)
    # Two test utterances, three intents in the fixture.
    assert scores.shape == (2, 3)
    # Every score must lie in the closed interval [0, 1].
    assert ((scores >= 0) & (scores <= 1)).all()
    # Single-label case: the scores of each row are expected to add up to 1.
    row_totals = scores.sum(axis=1)
    np.testing.assert_allclose(row_totals, 1.0, atol=1e-6)
4369
4470
4571def test_gcn_scorer_dump_load (tmp_path , multilabel_dataset ):
46- scorer = GCNScorer (
47- embedder_config = "prajjwal1/bert-tiny" ,
48- num_train_epochs = 1 ,
49- batch_size = 2 ,
50- )
72+ scorer = GCNScorer (embedder_config = "prajjwal1/bert-tiny" , num_train_epochs = 1 , batch_size = 2 )
5173 train_utterances = multilabel_dataset ["train" ]["utterance" ]
5274 train_labels = multilabel_dataset ["train" ]["label" ]
5375 scorer .fit (train_utterances , train_labels )
54-
55- test_utterances = ["test utterance 1" , "test utterance 2" ]
76+ test_utterances = ["test utterance 1" ]
5677 original_predictions = scorer .predict (test_utterances )
5778
5879 scorer .dump (str (tmp_path ))
0 commit comments