@@ -14,7 +14,11 @@ def test_lora_scorer_dump_load(dataset):
1414 data_handler = DataHandler (dataset )
1515
1616 # Create and train scorer
17- scorer_original = BERTLoRAScorer (classification_model_config = "prajjwal1/bert-tiny" , num_train_epochs = 1 , batch_size = 8 )
17+ scorer_original = BERTLoRAScorer (
18+ classification_model_config = "prajjwal1/bert-tiny" ,
19+ num_train_epochs = 1 ,
20+ batch_size = 8
21+ )
1822 scorer_original .fit (data_handler .train_utterances (0 ), data_handler .train_labels (0 ))
1923
2024 # Test data
@@ -33,7 +37,11 @@ def test_lora_scorer_dump_load(dataset):
3337 scorer_original .dump (str (temp_dir_path ))
3438
3539 # Create a new scorer and load saved model
36- scorer_loaded = BERTLoRAScorer (classification_model_config = "prajjwal1/bert-tiny" , num_train_epochs = 1 , batch_size = 8 )
40+ scorer_loaded = BERTLoRAScorer (
41+ classification_model_config = "prajjwal1/bert-tiny" ,
42+ num_train_epochs = 1 ,
43+ batch_size = 8
44+ )
3745 scorer_loaded .load (str (temp_dir_path ))
3846
3947 # Verify model and tokenizer are loaded
@@ -113,4 +121,4 @@ def test_lora_cache_clearing(dataset):
113121
114122 # Should raise exception after clearing cache
115123 with pytest .raises (RuntimeError ):
116- scorer .predict (test_data )
124+ scorer .predict (test_data )
0 commit comments