|
| 1 | +#!/usr/bin/env python3 |
| 2 | +# MIT License |
| 3 | + |
| 4 | +# Copyright (c) 2025 The HuggingFace Team |
| 5 | + |
| 6 | +""" |
| 7 | +Quick smoke tests for vLLM 0.11.0 compatibility. |
| 8 | +
|
| 9 | +This script performs basic sanity checks to verify that vLLM 0.11.0 |
| 10 | +works with lighteval's integration. |
| 11 | +""" |
| 12 | + |
| 13 | +import sys |
| 14 | + |
| 15 | + |
def test_vllm_import():
    """Check that the ``vllm`` package can be imported.

    Prints the reported ``vllm.__version__`` on success, or the import
    error on failure.

    Returns:
        bool: True when the import succeeds, False when it raises
        ImportError.
    """
    print("Testing vLLM import...")
    try:
        import vllm
    except ImportError as exc:
        print(f"✗ Failed to import vLLM: {exc}")
        return False

    print(f"✓ vLLM imported successfully. Version: {vllm.__version__}")
    return True
| 27 | + |
| 28 | + |
def test_vllm_version():
    """Check that the installed vLLM reports version 0.11 or newer.

    Only the first two dot-separated components of ``vllm.__version__``
    are parsed (as integers) and compared against ``(0, 11)``.

    Returns:
        bool: True when the version is at least 0.11, False when it is
        older or when the import / version parsing raises.
    """
    print("\nTesting vLLM version...")
    try:
        import vllm

        version = vllm.__version__
        leading_parts = version.split(".")[:2]
        major, minor = (int(part) for part in leading_parts)
    except Exception as exc:
        print(f"✗ Failed to check vLLM version: {exc}")
        return False

    # Lexicographic tuple comparison: any major > 0 passes, major == 0
    # needs minor >= 11.
    if (major, minor) >= (0, 11):
        print(f"✓ vLLM version {version} is 0.11.0 or higher")
        return True

    print(f"✗ vLLM version {version} is lower than 0.11.0")
    return False
| 47 | + |
| 48 | + |
def test_v1_engine_imports():
    """Check that the V1 async engine symbols can be imported.

    Imports ``AsyncEngineArgs`` and ``AsyncLLM`` from
    ``vllm.v1.engine.async_llm``.

    Returns:
        bool: True when both names import, False on ImportError.
    """
    print("\nTesting V1 engine imports...")
    try:
        from vllm.v1.engine.async_llm import AsyncEngineArgs, AsyncLLM  # noqa: F401
    except ImportError as exc:
        print(f"✗ Failed to import V1 engine: {exc}")
        return False
    else:
        print("✓ V1 AsyncLLM engine imports successful")
        return True
| 60 | + |
| 61 | + |
def test_v0_engine_removed():
    """Check that the V0 async engine module is no longer importable.

    The check passes when importing ``AsyncLLMEngine`` from
    ``vllm.engine.async_llm`` raises ImportError. Because a missing vLLM
    installation raises the same exception, vLLM itself is imported first;
    if that fails the check is reported as failed rather than silently
    passing.

    Returns:
        bool: True when vLLM is importable and the V0 import fails,
        False when vLLM is missing or the V0 import still succeeds.
    """
    print("\nTesting V0 engine removal...")

    # Without this guard an absent vLLM install would be reported as
    # "V0 engine properly removed".
    try:
        import vllm  # noqa: F401
    except ImportError as e:
        print(f"✗ Cannot check V0 engine removal, vLLM is not importable: {e}")
        return False

    # NOTE(review): the V0 engine appears to have lived in
    # `vllm.engine.async_llm_engine`; confirm that `vllm.engine.async_llm`
    # is the intended module to probe, otherwise this check is vacuous.
    try:
        from vllm.engine.async_llm import AsyncLLMEngine  # noqa: F401
    except ImportError:
        print("✓ V0 engine properly removed")
        return True

    print("✗ V0 engine still present (should be removed in 0.11.0)")
    return False
| 73 | + |
| 74 | + |
def test_core_imports():
    """Check that the core vLLM symbols used by the integration import.

    Imports ``LLM``, ``RequestOutput`` and ``SamplingParams`` from ``vllm``,
    the two distributed teardown helpers from
    ``vllm.distributed.parallel_state``, and ``get_tokenizer``.

    ``get_tokenizer`` is looked up in ``vllm.tokenizers`` first and then in
    ``vllm.transformers_utils.tokenizer`` so the check works on either
    module layout.

    Returns:
        bool: True when every import succeeds, False on ImportError.
    """
    print("\nTesting core vLLM imports...")
    try:
        from vllm import LLM, RequestOutput, SamplingParams  # noqa: F401
        from vllm.distributed.parallel_state import (  # noqa: F401
            destroy_distributed_environment,
            destroy_model_parallel,
        )

        # NOTE(review): `vllm.tokenizers` looks newer than the 0.11.0
        # release this script targets; the older location is tried as a
        # fallback. Confirm against the vLLM API docs for each version.
        try:
            from vllm.tokenizers import get_tokenizer  # noqa: F401
        except ImportError:
            from vllm.transformers_utils.tokenizer import get_tokenizer  # noqa: F401
    except ImportError as e:
        print(f"✗ Failed to import core components: {e}")
        return False

    print("✓ All core vLLM components imported successfully")
    return True
| 91 | + |
| 92 | + |
def test_lighteval_vllm_imports():
    """Check that lighteval's vLLM model classes can be imported.

    Imports ``AsyncVLLMModel``, ``VLLMModel`` and ``VLLMModelConfig`` from
    ``lighteval.models.vllm.vllm_model``.

    Returns:
        bool: True when the import succeeds, False on ImportError.
    """
    print("\nTesting lighteval vLLM integration imports...")
    try:
        from lighteval.models.vllm.vllm_model import AsyncVLLMModel, VLLMModel, VLLMModelConfig  # noqa: F401
    except ImportError as exc:
        print(f"✗ Failed to import lighteval vLLM integration: {exc}")
        return False

    print("✓ Lighteval vLLM integration imports successful")
    return True
| 104 | + |
| 105 | + |
def test_model_config_creation():
    """Check that a ``VLLMModelConfig`` can be instantiated.

    Builds a config for a small model with single-way tensor and data
    parallelism and prints a few of its fields.

    Returns:
        bool: True when construction and field access succeed, False when
        anything in the process raises.
    """
    print("\nTesting VLLMModelConfig creation...")
    try:
        from lighteval.models.vllm.vllm_model import VLLMModelConfig

        config_kwargs = {
            "model_name": "HuggingFaceTB/SmolLM2-135M-Instruct",
            "tensor_parallel_size": 1,
            "data_parallel_size": 1,
            "gpu_memory_utilization": 0.3,
        }
        cfg = VLLMModelConfig(**config_kwargs)

        # Field reads stay inside the try so a missing attribute is
        # reported as a failure instead of crashing the test.
        print("✓ VLLMModelConfig created successfully")
        print(f" Model: {cfg.model_name}")
        print(f" TP: {cfg.tensor_parallel_size}, DP: {cfg.data_parallel_size}")
    except Exception as exc:
        print(f"✗ Failed to create VLLMModelConfig: {exc}")
        return False
    return True
| 126 | + |
| 127 | + |
def test_sampling_params():
    """Check that ``vllm.SamplingParams`` accepts common sampling options.

    Constructs an instance with temperature, top-p, top-k, max-tokens and
    a stop sequence, then prints three of the resulting fields.

    Returns:
        bool: True when construction and field access succeed, False when
        anything in the process raises.
    """
    print("\nTesting SamplingParams creation...")
    try:
        from vllm import SamplingParams

        sampling_kwargs = {
            "temperature": 0.7,
            "top_p": 0.9,
            "top_k": 50,
            "max_tokens": 100,
            "stop": ["</s>"],
        }
        sampling = SamplingParams(**sampling_kwargs)

        # Field reads stay inside the try so a missing attribute is
        # reported as a failure instead of crashing the test.
        print("✓ SamplingParams created successfully")
        print(f" Temperature: {sampling.temperature}")
        print(f" Top-p: {sampling.top_p}")
        print(f" Max tokens: {sampling.max_tokens}")
    except Exception as exc:
        print(f"✗ Failed to create SamplingParams: {exc}")
        return False
    return True
| 144 | + |
| 145 | + |
def main():
    """Run every smoke test in order and print a pass/fail summary.

    A test that raises is reported as crashed and counted as a failure;
    the remaining tests still run.

    Returns:
        int: 0 when all tests pass, 1 otherwise (suitable as a process
        exit status).
    """
    rule = "=" * 60
    print(rule)
    print("vLLM 0.11.0 Compatibility Smoke Tests")
    print(rule)

    checks = (
        test_vllm_import,
        test_vllm_version,
        test_v1_engine_imports,
        test_v0_engine_removed,
        test_core_imports,
        test_lighteval_vllm_imports,
        test_model_config_creation,
        test_sampling_params,
    )

    outcomes = []
    for check in checks:
        try:
            outcome = check()
        except Exception as exc:
            print(f"✗ Test {check.__name__} crashed: {exc}")
            outcome = False
        outcomes.append(outcome)

    print("\n" + rule)
    total = len(outcomes)
    passed = sum(outcomes)
    print(f"Results: {passed}/{total} tests passed")

    if passed != total:
        print(f"✗ {total - passed} test(s) failed")
        print(rule)
        return 1

    print("✓ All smoke tests passed!")
    print(rule)
    return 0
| 185 | + |
| 186 | + |
# Script entry point: propagate the smoke-test result as the exit status.
if __name__ == "__main__":
    exit_status = main()
    sys.exit(exit_status)
0 commit comments