Commit f77bfba

Author: Aleksandr Malyshev
Message: sync with 0909_rc2 changes
Parent commit: fc0dbad

File tree

3 files changed, 24 insertions(+), 19 deletions(-)

  tests/entrypoints/openai/test_serving_chat.py
  vllm/attention/ops/prefix_prefill.py
  vllm/attention/ops/triton_unified_attention.py

tests/entrypoints/openai/test_serving_chat.py

Lines changed: 22 additions & 16 deletions
@@ -7,6 +7,8 @@
 from typing import Any, Optional
 from unittest.mock import MagicMock

+import pytest
+
 from vllm.config import MultiModalConfig
 from vllm.engine.multiprocessing.client import MQLLMEngineClient
 from vllm.entrypoints.openai.protocol import ChatCompletionRequest
@@ -73,7 +75,8 @@ def test_async_serving_chat_init():
     assert serving_completion.chat_template == CHAT_TEMPLATE


-def test_serving_chat_should_set_correct_max_tokens():
+@pytest.mark.asyncio
+async def test_serving_chat_should_set_correct_max_tokens():
     mock_engine = MagicMock(spec=MQLLMEngineClient)
     mock_engine.get_tokenizer.return_value = get_tokenizer(MODEL_NAME)
     mock_engine.errored = False
@@ -88,6 +91,7 @@ def test_serving_chat_should_set_correct_max_tokens():
         chat_template=CHAT_TEMPLATE,
         chat_template_content_format="auto",
         request_logger=None)
+
     req = ChatCompletionRequest(
         model=MODEL_NAME,
         messages=[{
@@ -98,13 +102,13 @@ def test_serving_chat_should_set_correct_max_tokens():
     )

     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)

     assert mock_engine.generate.call_args.args[1].max_tokens == 93

     req.max_tokens = 10
     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)

     assert mock_engine.generate.call_args.args[1].max_tokens == 10

@@ -143,23 +147,23 @@ def test_serving_chat_should_set_correct_max_tokens():
     )

     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)

     assert mock_engine.generate.call_args.args[1].max_tokens == 10

     # Test Case 2: Request's max_tokens set higher than server accepts
     req.max_tokens = 15

     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)

     assert mock_engine.generate.call_args.args[1].max_tokens == 10

     # Test Case 3: Request's max_tokens set lower than server accepts
     req.max_tokens = 5

     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)

     assert mock_engine.generate.call_args.args[1].max_tokens == 5

@@ -198,28 +202,29 @@ def test_serving_chat_should_set_correct_max_tokens():
     )

     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)

     assert mock_engine.generate.call_args.args[1].max_tokens == 93

     # Test Case 2: Request's max_tokens set higher than server accepts
     req.max_tokens = 100

     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)

     assert mock_engine.generate.call_args.args[1].max_tokens == 93

     # Test Case 3: Request's max_tokens set lower than server accepts
     req.max_tokens = 5

     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)

     assert mock_engine.generate.call_args.args[1].max_tokens == 5


-def test_serving_chat_could_load_correct_generation_config():
+@pytest.mark.asyncio
+async def test_serving_chat_could_load_correct_generation_config():

     mock_model_config = MockModelConfig()
     mock_model_config.diff_sampling_param = {
@@ -242,6 +247,7 @@ def test_serving_chat_could_load_correct_generation_config():
         chat_template=CHAT_TEMPLATE,
         chat_template_content_format="auto",
         request_logger=None)
+
     req = ChatCompletionRequest(
         model=MODEL_NAME,
         messages=[{
@@ -252,7 +258,7 @@ def test_serving_chat_could_load_correct_generation_config():
     )

     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)

     assert mock_engine.generate.call_args.args[1].temperature == 0.5
     assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
@@ -261,7 +267,7 @@ def test_serving_chat_could_load_correct_generation_config():
     req.temperature = 0.1

     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)

     assert mock_engine.generate.call_args.args[1].temperature == 0.1
     assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
@@ -270,7 +276,7 @@ def test_serving_chat_could_load_correct_generation_config():
     req.temperature = 0.0

     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)

     assert mock_engine.generate.call_args.args[1].temperature == 0.0
     assert mock_engine.generate.call_args.args[1].repetition_penalty == 1.05
@@ -309,11 +315,11 @@ async def test_serving_chat_did_set_correct_cache_salt(model_type):

     # By default cache_salt in the engine prompt is not set
     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
+        await serving_chat.create_chat_completion(req)
     assert "cache_salt" not in mock_engine.generate.call_args.args[0]

     # Test with certain cache_salt
     req.cache_salt = "test_salt"
     with suppress(Exception):
-        asyncio.run(serving_chat.create_chat_completion(req))
-    assert mock_engine.generate.call_args.args[0]["cache_salt"] == "test_salt"
+        await serving_chat.create_chat_completion(req)
+    assert mock_engine.generate.call_args.args[0]["cache_salt"] == "test_salt"
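
The diff above converts tests that drove the async create_chat_completion coroutine through asyncio.run(...) inside synchronous test functions into native async tests that await the call directly, marked with @pytest.mark.asyncio. A minimal, hedged sketch of that pattern, assuming the pytest-asyncio plugin is installed; the coroutine below is a stand-in, not the real OpenAIServingChat API:

import pytest


@pytest.mark.asyncio
async def test_awaits_coroutine_directly():
    # Stand-in for an async API such as create_chat_completion.
    async def create_chat_completion():
        return 42

    # Before the change: result = asyncio.run(create_chat_completion())
    # inside a plain "def test_...". After: the test itself is a coroutine
    # and awaits the call on the event loop supplied by pytest-asyncio.
    result = await create_chat_completion()
    assert result == 42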

vllm/attention/ops/prefix_prefill.py

Lines changed: 1 addition & 1 deletion
@@ -84,8 +84,8 @@ def _fwd_kernel(Q,
                num_unroll_cache: tl.constexpr,
                num_unroll_request: tl.constexpr,
                SKIP_DECODE: tl.constexpr,
-               USE_FP8: tl.constexpr,
                USE_SINKS: tl.constexpr,
+               USE_FP8: tl.constexpr,
                MAX_Q_LEN: tl.constexpr = 0,
                MAX_CTX_LEN: tl.constexpr = 0,
                FP8_MIN: tl.constexpr = float8_info.min,
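
This hunk only swaps the order of the USE_SINKS and USE_FP8 tl.constexpr parameters in the _fwd_kernel signature. A purely illustrative Python sketch of why such a swap matters (the function and flag names below are hypothetical, not vLLM's actual call site): positional callers written against the old order silently bind the wrong flags, while keyword callers are unaffected.

def fwd_kernel_old(skip_decode, use_fp8, use_sinks):
    return {"skip_decode": skip_decode, "use_fp8": use_fp8, "use_sinks": use_sinks}


def fwd_kernel_new(skip_decode, use_sinks, use_fp8):
    return {"skip_decode": skip_decode, "use_fp8": use_fp8, "use_sinks": use_sinks}


# A positional call written against the old order now means something different:
assert fwd_kernel_old(False, True, False)["use_fp8"] is True
assert fwd_kernel_new(False, True, False)["use_fp8"] is False

# Keyword calls are order-independent, so they survive the reordering unchanged:
assert fwd_kernel_new(skip_decode=False, use_fp8=True, use_sinks=False) == \
       fwd_kernel_old(skip_decode=False, use_fp8=True, use_sinks=False)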

vllm/attention/ops/triton_unified_attention.py

Lines changed: 1 addition & 2 deletions
@@ -12,7 +12,6 @@
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.triton_utils import tl, triton
-from vllm.platforms import current_platform

 logger = init_logger(__name__)
 float8_info = torch.finfo(current_platform.fp8_dtype())
@@ -667,8 +666,8 @@ def unified_attention(
     k_descale,
     v_descale,
     alibi_slopes=None,
-    qq_bias=None,
     output_scale=None,
+    qq_bias=None,
     # Optional tensor for sinks
     sinks=None,
 ):
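
Two small cleanups in this file: the duplicate "from vllm.platforms import current_platform" line is dropped, and qq_bias is moved after output_scale among the keyword parameters of unified_attention. Neither changes behavior for callers that pass these trailing arguments by keyword. A generic sketch (plain Python, not vLLM-specific) of why the duplicate import was harmless but redundant: a repeated import resolves from the sys.modules cache instead of re-executing the module.

import sys

import math
import math  # duplicate: satisfied from sys.modules, no re-execution

# Both statements bind the same cached module object.
assert sys.modules["math"] is math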
