@@ -45,7 +45,13 @@ def multiclass_dataset():
4545
4646def test_gcn_scorer_multilabel (multilabel_dataset ):
4747 torch .manual_seed (42 )
48- scorer = GCNScorer (embedder_config = get_test_embedder_config (), num_train_epochs = 1 , batch_size = 2 , seed = 42 )
48+ scorer = GCNScorer (
49+ embedder_config = get_test_embedder_config (),
50+ label_embedder_config = get_test_embedder_config (),
51+ num_train_epochs = 1 ,
52+ batch_size = 2 ,
53+ seed = 42 ,
54+ )
4955 train_utterances = multilabel_dataset ["train" ]["utterance" ]
5056 train_labels = multilabel_dataset ["train" ]["label" ]
5157 descriptions = [intent .name for intent in multilabel_dataset .intents ]
@@ -60,7 +66,13 @@ def test_gcn_scorer_multilabel(multilabel_dataset):
6066
6167def test_gcn_scorer_multiclass (multiclass_dataset ):
6268 torch .manual_seed (42 )
63- scorer = GCNScorer (embedder_config = get_test_embedder_config (), num_train_epochs = 1 , batch_size = 2 , seed = 42 )
69+ scorer = GCNScorer (
70+ embedder_config = get_test_embedder_config (),
71+ label_embedder_config = get_test_embedder_config (),
72+ num_train_epochs = 1 ,
73+ batch_size = 2 ,
74+ seed = 42 ,
75+ )
6476 train_utterances = multiclass_dataset ["train" ]["utterance" ]
6577 train_labels = multiclass_dataset ["train" ]["label" ]
6678 descriptions = [intent .name for intent in multiclass_dataset .intents ]
@@ -76,7 +88,13 @@ def test_gcn_scorer_multiclass(multiclass_dataset):
7688
7789def test_gcn_scorer_dump_load (tmp_path , multilabel_dataset ):
7890 torch .manual_seed (42 )
79- scorer = GCNScorer (embedder_config = get_test_embedder_config (), num_train_epochs = 1 , batch_size = 2 , seed = 42 )
91+ scorer = GCNScorer (
92+ embedder_config = get_test_embedder_config (),
93+ label_embedder_config = get_test_embedder_config (),
94+ num_train_epochs = 1 ,
95+ batch_size = 2 ,
96+ seed = 42 ,
97+ )
8098 train_utterances = multilabel_dataset ["train" ]["utterance" ]
8199 train_labels = multilabel_dataset ["train" ]["label" ]
82100 descriptions = [intent .name for intent in multilabel_dataset .intents ]
0 commit comments