
Commit a246912

[misc][ci] fix quant test (vllm-project#8449)
1 parent 06311e2 commit a246912


2 files changed: +21 -15 lines changed


tests/quantization/test_bitsandbytes.py
Lines changed: 20 additions & 12 deletions

@@ -10,6 +10,8 @@
 
 from tests.quantization.utils import is_quant_method_supported
 
+from ..utils import fork_new_process_for_each_test
+
 models_4bit_to_test = [
     ('huggyllama/llama-7b', 'quantize model inflight'),
 ]
@@ -29,6 +31,7 @@
 @pytest.mark.skipif(not is_quant_method_supported("bitsandbytes"),
                     reason='bitsandbytes is not supported on this GPU type.')
 @pytest.mark.parametrize("model_name, description", models_4bit_to_test)
+@fork_new_process_for_each_test
 def test_load_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                              model_name, description) -> None:
 
@@ -41,6 +44,7 @@ def test_load_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                     reason='bitsandbytes is not supported on this GPU type.')
 @pytest.mark.parametrize("model_name, description",
                          models_pre_qaunt_4bit_to_test)
+@fork_new_process_for_each_test
 def test_load_pre_quant_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                                        model_name, description) -> None:
 
@@ -52,6 +56,7 @@ def test_load_pre_quant_4bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                     reason='bitsandbytes is not supported on this GPU type.')
 @pytest.mark.parametrize("model_name, description",
                          models_pre_quant_8bit_to_test)
+@fork_new_process_for_each_test
 def test_load_8bit_bnb_model(hf_runner, vllm_runner, example_prompts,
                              model_name, description) -> None:
 
@@ -77,18 +82,8 @@ def validate_generated_texts(hf_runner,
                              model_name,
                              hf_model_kwargs=None):
 
-    if hf_model_kwargs is None:
-        hf_model_kwargs = {}
-
-    # Run with HF runner
-    with hf_runner(model_name, model_kwargs=hf_model_kwargs) as llm:
-        hf_outputs = llm.generate_greedy(prompts, 8)
-        hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner")
-
-    # Clean up the GPU memory for the next test
-    torch.cuda.synchronize()
-    gc.collect()
-    torch.cuda.empty_cache()
+    # NOTE: run vLLM first, as it requires a clean process
+    # when using distributed inference
 
     #Run with vLLM runner
     with vllm_runner(model_name,
@@ -104,6 +99,19 @@ def validate_generated_texts(hf_runner,
     gc.collect()
     torch.cuda.empty_cache()
 
+    if hf_model_kwargs is None:
+        hf_model_kwargs = {}
+
+    # Run with HF runner
+    with hf_runner(model_name, model_kwargs=hf_model_kwargs) as llm:
+        hf_outputs = llm.generate_greedy(prompts, 8)
+        hf_logs = log_generated_texts(prompts, hf_outputs, "HfRunner")
+
+    # Clean up the GPU memory for the next test
+    torch.cuda.synchronize()
+    gc.collect()
+    torch.cuda.empty_cache()
+
     # Compare the generated strings
     for hf_log, vllm_log in zip(hf_logs, vllm_logs):
         hf_str = hf_log["generated_text"]
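
The decorator applied to each test above, fork_new_process_for_each_test, is imported from tests/utils.py; its implementation is not part of this diff. Purely as an illustration of the general pattern (not vLLM's actual helper), a decorator of this kind typically forks a fresh child process per test so that CUDA context and distributed state cannot leak from one test into the next:

import os
import traceback
from functools import wraps


def fork_new_process_for_each_test(f):
    """Hypothetical sketch of a per-test fork decorator (POSIX only).

    The real helper lives in vLLM's tests/utils.py and is not shown in
    this commit; this sketch only illustrates the isolation idea.
    """

    @wraps(f)
    def wrapper(*args, **kwargs):
        pid = os.fork()
        if pid == 0:
            # Child process: run the test body and exit immediately so no
            # state or teardown leaks back to the parent. This only gives a
            # clean CUDA context if the parent has not initialized CUDA yet.
            try:
                f(*args, **kwargs)
                os._exit(0)
            except BaseException:
                traceback.print_exc()
                os._exit(1)
        # Parent process: wait for the child and fail the test if the
        # child did not exit cleanly.
        _, status = os.waitpid(pid, 0)
        assert status == 0, "test failed in forked child process"

    return wrapper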

tests/quantization/utils.py
Lines changed: 1 addition & 3 deletions

@@ -1,12 +1,10 @@
-import torch
-
 from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
 from vllm.platforms import current_platform
 
 
 def is_quant_method_supported(quant_method: str) -> bool:
     # Currently, all quantization methods require Nvidia or AMD GPUs
-    if not torch.cuda.is_available():
+    if not (current_platform.is_cuda() or current_platform.is_rocm()):
        return False
 
     capability = current_platform.get_device_capability()
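
The hunk stops at the capability lookup; the remainder of is_quant_method_supported is untouched by this commit and not shown here. For orientation only, a sketch of how such a gate plausibly finishes. The (major, minor) shape of the capability value and the get_min_capability() accessor are assumptions, not taken from this diff:

from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS
from vllm.platforms import current_platform


def is_quant_method_supported(quant_method: str) -> bool:
    # Currently, all quantization methods require Nvidia or AMD GPUs
    if not (current_platform.is_cuda() or current_platform.is_rocm()):
        return False

    capability = current_platform.get_device_capability()
    # Assumed continuation: fold (major, minor) into a single number and
    # compare it against the method's minimum supported compute capability.
    major, minor = capability
    return (major * 10 + minor >=
            QUANTIZATION_METHODS[quant_method].get_min_capability())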
