Skip to content

Commit 889c2e5

Browse files
committed
Update test_lora.py
1 parent 8abba9f commit 889c2e5

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

tests/modules/scoring/test_lora.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,11 @@ def test_lora_scorer_dump_load(dataset):
1414
data_handler = DataHandler(dataset)
1515

1616
# Create and train scorer
17-
scorer_original = BERTLoRAScorer(classification_model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8)
17+
scorer_original = BERTLoRAScorer(
18+
classification_model_config="prajjwal1/bert-tiny",
19+
num_train_epochs=1,
20+
batch_size=8
21+
)
1822
scorer_original.fit(data_handler.train_utterances(0), data_handler.train_labels(0))
1923

2024
# Test data
@@ -33,7 +37,11 @@ def test_lora_scorer_dump_load(dataset):
3337
scorer_original.dump(str(temp_dir_path))
3438

3539
# Create a new scorer and load saved model
36-
scorer_loaded = BERTLoRAScorer(classification_model_config="prajjwal1/bert-tiny", num_train_epochs=1, batch_size=8)
40+
scorer_loaded = BERTLoRAScorer(
41+
classification_model_config="prajjwal1/bert-tiny",
42+
num_train_epochs=1,
43+
batch_size=8
44+
)
3745
scorer_loaded.load(str(temp_dir_path))
3846

3947
# Verify model and tokenizer are loaded
@@ -113,4 +121,4 @@ def test_lora_cache_clearing(dataset):
113121

114122
# Should raise exception after clearing cache
115123
with pytest.raises(RuntimeError):
116-
scorer.predict(test_data)
124+
scorer.predict(test_data)

0 commit comments

Comments
 (0)