Skip to content

Commit d2c095c

Browse files
committed
fix gcn scorer tests
1 parent f8dd176 commit d2c095c

File tree

1 file changed

+21
-3
lines changed

1 file changed

+21
-3
lines changed

tests/modules/scoring/test_gcn_scorer.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,13 @@ def multiclass_dataset():
4545

4646
def test_gcn_scorer_multilabel(multilabel_dataset):
4747
torch.manual_seed(42)
48-
scorer = GCNScorer(embedder_config=get_test_embedder_config(), num_train_epochs=1, batch_size=2, seed=42)
48+
scorer = GCNScorer(
49+
embedder_config=get_test_embedder_config(),
50+
label_embedder_config=get_test_embedder_config(),
51+
num_train_epochs=1,
52+
batch_size=2,
53+
seed=42,
54+
)
4955
train_utterances = multilabel_dataset["train"]["utterance"]
5056
train_labels = multilabel_dataset["train"]["label"]
5157
descriptions = [intent.name for intent in multilabel_dataset.intents]
@@ -60,7 +66,13 @@ def test_gcn_scorer_multilabel(multilabel_dataset):
6066

6167
def test_gcn_scorer_multiclass(multiclass_dataset):
6268
torch.manual_seed(42)
63-
scorer = GCNScorer(embedder_config=get_test_embedder_config(), num_train_epochs=1, batch_size=2, seed=42)
69+
scorer = GCNScorer(
70+
embedder_config=get_test_embedder_config(),
71+
label_embedder_config=get_test_embedder_config(),
72+
num_train_epochs=1,
73+
batch_size=2,
74+
seed=42,
75+
)
6476
train_utterances = multiclass_dataset["train"]["utterance"]
6577
train_labels = multiclass_dataset["train"]["label"]
6678
descriptions = [intent.name for intent in multiclass_dataset.intents]
@@ -76,7 +88,13 @@ def test_gcn_scorer_multiclass(multiclass_dataset):
7688

7789
def test_gcn_scorer_dump_load(tmp_path, multilabel_dataset):
7890
torch.manual_seed(42)
79-
scorer = GCNScorer(embedder_config=get_test_embedder_config(), num_train_epochs=1, batch_size=2, seed=42)
91+
scorer = GCNScorer(
92+
embedder_config=get_test_embedder_config(),
93+
label_embedder_config=get_test_embedder_config(),
94+
num_train_epochs=1,
95+
batch_size=2,
96+
seed=42,
97+
)
8098
train_utterances = multilabel_dataset["train"]["utterance"]
8199
train_labels = multilabel_dataset["train"]["label"]
82100
descriptions = [intent.name for intent in multilabel_dataset.intents]

0 commit comments

Comments
 (0)