1- import pytest
21from unittest .mock import MagicMock
2+
33from datasets import Dataset
4+ from transformers import BitsAndBytesConfig
45
56from llmtune .inference .lora import LoRAInference
6- from llmtune .utils .save_utils import DirectoryHelper
77from test_utils .test_config import get_sample_config # Adjust import path as needed
88
9- from transformers import BitsAndBytesConfig
10-
119
1210def test_lora_inference_initialization (mocker ):
1311 # Mock dependencies
1412 mock_model = mocker .patch (
1513 "llmtune.inference.lora.AutoPeftModelForCausalLM.from_pretrained" ,
1614 return_value = MagicMock (),
1715 )
18- mock_tokenizer = mocker .patch (
19- "llmtune.inference.lora.AutoTokenizer.from_pretrained" , return_value = MagicMock ()
20- )
16+ mock_tokenizer = mocker .patch ("llmtune.inference.lora.AutoTokenizer.from_pretrained" , return_value = MagicMock ())
2117
2218 # Mock configuration and directory helper
2319 config = get_sample_config ()
24- dir_helper = MagicMock (
25- save_paths = MagicMock (results = "results_dir" , weights = "weights_dir" )
26- )
20+ dir_helper = MagicMock (save_paths = MagicMock (results = "results_dir" , weights = "weights_dir" ))
2721 test_dataset = Dataset .from_dict (
2822 {
2923 "formatted_prompt" : ["prompt1" , "prompt2" ],
3024 "label_column_name" : ["label1" , "label2" ],
3125 }
3226 )
3327
34- inference = LoRAInference (
28+ _ = LoRAInference (
3529 test_dataset = test_dataset ,
3630 label_column_name = "label_column_name" ,
3731 config = config ,
@@ -45,34 +39,24 @@ def test_lora_inference_initialization(mocker):
4539 device_map = config .model .device_map ,
4640 attn_implementation = config .model .attn_implementation ,
4741 )
48- mock_tokenizer .assert_called_once_with (
49- "weights_dir" , device_map = config .model .device_map
50- )
42+ mock_tokenizer .assert_called_once_with ("weights_dir" , device_map = config .model .device_map )
5143
5244
5345def test_infer_all (mocker ):
5446 mocker .patch (
5547 "llmtune.inference.lora.AutoPeftModelForCausalLM.from_pretrained" ,
5648 return_value = MagicMock (),
5749 )
58- mocker .patch (
59- "llmtune.inference.lora.AutoTokenizer.from_pretrained" , return_value = MagicMock ()
60- )
50+ mocker .patch ("llmtune.inference.lora.AutoTokenizer.from_pretrained" , return_value = MagicMock ())
6151 mocker .patch ("os.makedirs" )
6252 mock_open = mocker .patch ("builtins.open" , mocker .mock_open ())
6353 mock_csv_writer = mocker .patch ("csv.writer" )
6454
65- mock_infer_one = mocker .patch .object (
66- LoRAInference , "infer_one" , return_value = "predicted"
67- )
55+ mock_infer_one = mocker .patch .object (LoRAInference , "infer_one" , return_value = "predicted" )
6856
6957 config = get_sample_config ()
70- dir_helper = MagicMock (
71- save_paths = MagicMock (results = "results_dir" , weights = "weights_dir" )
72- )
73- test_dataset = Dataset .from_dict (
74- {"formatted_prompt" : ["prompt1" ], "label_column_name" : ["label1" ]}
75- )
58+ dir_helper = MagicMock (save_paths = MagicMock (results = "results_dir" , weights = "weights_dir" ))
59+ test_dataset = Dataset .from_dict ({"formatted_prompt" : ["prompt1" ], "label_column_name" : ["label1" ]})
7660
7761 inference = LoRAInference (
7862 test_dataset = test_dataset ,
0 commit comments