@@ -14,11 +14,12 @@ def test_ptuning_scorer_dump_load(dataset):
1414 data_handler = DataHandler (dataset )
1515
1616 scorer_original = PTuningScorer (
17- base_model_config = "prajjwal1/bert-tiny" ,
17+ classification_model_config = "prajjwal1/bert-tiny" ,
1818 num_train_epochs = 1 ,
1919 batch_size = 8 ,
2020 task_type = "SEQ_CLS" ,
2121 num_virtual_tokens = 10 ,
22+ seed = 42 ,
2223 )
2324 scorer_original .fit (data_handler .train_utterances (0 ), data_handler .train_labels (0 ))
2425
@@ -34,11 +35,12 @@ def test_ptuning_scorer_dump_load(dataset):
3435 scorer_original .dump (str (temp_dir_path ))
3536
3637 scorer_loaded = PTuningScorer (
37- base_model_config = "prajjwal1/bert-tiny" ,
38+ classification_model_config = "prajjwal1/bert-tiny" ,
3839 num_train_epochs = 1 ,
3940 batch_size = 8 ,
4041 task_type = "SEQ_CLS" ,
4142 num_virtual_tokens = 10 ,
43+ seed = 42 ,
4244 )
4345 scorer_loaded .load (str (temp_dir_path ))
4446
@@ -61,11 +63,12 @@ def test_ptuning_prediction(dataset):
6163 data_handler = DataHandler (dataset )
6264
6365 scorer = PTuningScorer (
64- base_model_config = "prajjwal1/bert-tiny" ,
66+ classification_model_config = "prajjwal1/bert-tiny" ,
6567 num_train_epochs = 1 ,
6668 batch_size = 8 ,
6769 task_type = "SEQ_CLS" ,
6870 num_virtual_tokens = 10 ,
71+ seed = 42 ,
6972 )
7073
7174 scorer .fit (data_handler .train_utterances (0 ), data_handler .train_labels (0 ))
@@ -100,11 +103,12 @@ def test_ptuning_cache_clearing(dataset):
100103 data_handler = DataHandler (dataset )
101104
102105 scorer = PTuningScorer (
103- base_model_config = "prajjwal1/bert-tiny" ,
106+ classification_model_config = "prajjwal1/bert-tiny" ,
104107 num_train_epochs = 1 ,
105108 batch_size = 8 ,
106109 task_type = "SEQ_CLS" ,
107110 num_virtual_tokens = 20 ,
111+ seed = 42 ,
108112 )
109113
110114 scorer .fit (data_handler .train_utterances (0 ), data_handler .train_labels (0 ))
0 commit comments