Skip to content

Commit ae2011a

Browse files
committed
Refactor test for ptuning
1 parent c414e1a commit ae2011a

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

tests/modules/scoring/test_ptuning.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,12 @@ def test_ptuning_scorer_dump_load(dataset):
1414
data_handler = DataHandler(dataset)
1515

1616
scorer_original = PTuningScorer(
17-
base_model_config="prajjwal1/bert-tiny",
17+
classification_model_config="prajjwal1/bert-tiny",
1818
num_train_epochs=1,
1919
batch_size=8,
2020
task_type="SEQ_CLS",
2121
num_virtual_tokens=10,
22+
seed=42,
2223
)
2324
scorer_original.fit(data_handler.train_utterances(0), data_handler.train_labels(0))
2425

@@ -34,11 +35,12 @@ def test_ptuning_scorer_dump_load(dataset):
3435
scorer_original.dump(str(temp_dir_path))
3536

3637
scorer_loaded = PTuningScorer(
37-
base_model_config="prajjwal1/bert-tiny",
38+
classification_model_config="prajjwal1/bert-tiny",
3839
num_train_epochs=1,
3940
batch_size=8,
4041
task_type="SEQ_CLS",
4142
num_virtual_tokens=10,
43+
seed=42,
4244
)
4345
scorer_loaded.load(str(temp_dir_path))
4446

@@ -61,11 +63,12 @@ def test_ptuning_prediction(dataset):
6163
data_handler = DataHandler(dataset)
6264

6365
scorer = PTuningScorer(
64-
base_model_config="prajjwal1/bert-tiny",
66+
classification_model_config="prajjwal1/bert-tiny",
6567
num_train_epochs=1,
6668
batch_size=8,
6769
task_type="SEQ_CLS",
6870
num_virtual_tokens=10,
71+
seed=42,
6972
)
7073

7174
scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0))
@@ -100,11 +103,12 @@ def test_ptuning_cache_clearing(dataset):
100103
data_handler = DataHandler(dataset)
101104

102105
scorer = PTuningScorer(
103-
base_model_config="prajjwal1/bert-tiny",
106+
classification_model_config="prajjwal1/bert-tiny",
104107
num_train_epochs=1,
105108
batch_size=8,
106109
task_type="SEQ_CLS",
107110
num_virtual_tokens=20,
111+
seed=42,
108112
)
109113

110114
scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0))

0 commit comments

Comments
 (0)