@@ -10,6 +10,8 @@
 
 from tests.quantization.utils import is_quant_method_supported
 
+from ..utils import fork_new_process_for_each_test
+
 models_4bit_to_test = [
     ('huggyllama/llama-7b', 'quantize model inflight'),
 ]
@@ -29,6 +31,7 @@
 @pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
                     reason='bitsandbytes is not supported on this GPU type.')
 @pytest.mark.parametrize("model_name, description", models_4bit_to_test)
+@fork_new_process_for_each_test
 def test_load_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                              model_name, description) -> None:
 
@@ -41,6 +44,7 @@ def test_load_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                     reason='bitsandbytes is not supported on this GPU type.')
 @pytest.mark.parametrize("model_name, description",
                          models_pre_qaunt_4bit_to_test)
+@fork_new_process_for_each_test
 def test_load_pre_quant_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                                        model_name, description) -> None:
 
@@ -52,6 +56,7 @@ def test_load_pre_quant_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                     reason='bitsandbytes is not supported on this GPU type.')
 @pytest.mark.parametrize("model_name, description",
                          models_pre_quant_8bit_to_test)
+@fork_new_process_for_each_test
 def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                              model_name, description) -> None:
 
@@ -77,18 +82,8 @@ def validate_generated_texts(hf_runner,
                              model_name,
                              hf_model_kwargs=None):
 
-    if hf_model_kwargs is None:
-        hf_model_kwargs = {}
-
-    # Run with HF runner
-    with hf_runner(model_name, model_kwargs=hf_model_kwargs) as llm:
-        hf_outputs = llm.generate_greedy(prompts, 8)
-        hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner")
-
-    # Clean up the GPU memory for the next test
-    torch.cuda.synchronize()
-    gc.collect()
-    torch.cuda.empty_cache()
+    # NOTE: run vLLM first, as it requires a clean process
+    # when using distributed inference
 
     #Run with vLLM runner
     with vllm_runner(model_name,
@@ -104,6 +99,19 @@ def validate_generated_texts(hf_runner,
     gc.collect()
     torch.cuda.empty_cache()
 
+    if hf_model_kwargs is None:
+        hf_model_kwargs = {}
+
+    # Run with HF runner
+    with hf_runner(model_name, model_kwargs=hf_model_kwargs) as llm:
+        hf_outputs = llm.generate_greedy(prompts, 8)
+        hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner")
+
+    # Clean up the GPU memory for the next test
+    torch.cuda.synchronize()
+    gc.collect()
+    torch.cuda.empty_cache()
+
     # Compare the generated strings
     for hf_log, vllm_log in zip(hf_logs, vllm_logs):
         hf_str = hf_log["generated_text"]
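For context, the change relies on two ideas: each bitsandbytes test is wrapped in fork_new_process_for_each_test so it runs in a fresh process, and validate_generated_texts now runs the vLLM side before the HF side because distributed inference expects a clean process. The sketch below shows roughly how a fork-per-test decorator can work; it is a minimal, POSIX-only illustration of the technique, not the actual helper imported from ..utils.

import functools
import os


def fork_new_process_for_each_test(f):
    """Illustrative sketch: run the wrapped test in a forked child so GPU and
    distributed state cannot leak into later tests (not vLLM's real helper)."""

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        pid = os.fork()
        if pid == 0:
            # Child process: run the test and report the outcome via exit code.
            try:
                f(*args, **kwargs)
            except BaseException:
                os._exit(1)
            os._exit(0)
        # Parent process: wait for the child and fail if it did not exit cleanly.
        _, status = os.waitpid(pid, 0)
        assert os.WIFEXITED(status) and os.WEXITSTATUS(status) == 0, (
            f"forked test {f.__name__} failed")

    return wrapper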