Commit 143844f

[XPU] Fix xpu spec decoding UTs, avoid using cuda graph (vllm-project#25847)
Signed-off-by: Kunshang Ji <[email protected]>
Parent commit: 219cfbe

3 files changed: +7, -4 lines

.buildkite/scripts/hardware_ci/run-xpu-test.sh (1 addition, 1 deletion)

@@ -42,7 +42,7 @@ docker run \
 pytest -v -s v1/sample --ignore=v1/sample/test_logprobs.py --ignore=v1/sample/test_logprobs_e2e.py
 pytest -v -s v1/worker --ignore=v1/worker/test_gpu_model_runner.py
 pytest -v -s v1/structured_output
-pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_eagle.py --ignore=v1/spec_decode/test_tree_attention.py
+pytest -v -s v1/spec_decode --ignore=v1/spec_decode/test_max_len.py --ignore=v1/spec_decode/test_tree_attention.py
 pytest -v -s v1/kv_connector/unit --ignore=v1/kv_connector/unit/test_multi_connector.py --ignore=v1/kv_connector/unit/test_nixl_connector.py --ignore=v1/kv_connector/unit/test_shared_storage_connector.py
 pytest -v -s v1/test_metrics
 pytest -v -s v1/test_serial_utils.py

tests/utils.py (2 additions, 0 deletions)

@@ -1143,6 +1143,8 @@ def get_attn_backend_list_based_on_platform() -> list[str]:
             print("Skip FLASH_ATTN on ROCm as aiter is not installed")

         return attn_backend_list
+    elif current_platform.is_xpu():
+        return ["FLASH_ATTN", "TRITON_ATTN"]
     else:
         raise ValueError("Unsupported platform")
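
The new branch lets platform-parametrized tests pick up the attention backends supported on XPU. As a rough illustration only: the parametrize pattern, the import path, and the VLLM_ATTENTION_BACKEND environment-variable wiring below are assumptions about how a spec-decode test might consume this helper, not code from this commit.

import pytest

# Assumed import path; vLLM's test suite exposes this helper from tests/utils.py.
from tests.utils import get_attn_backend_list_based_on_platform


# Hypothetical usage sketch: run a spec-decode test once per attention
# backend that the current platform (CUDA, ROCm, or now XPU) supports.
@pytest.mark.parametrize("attn_backend",
                         get_attn_backend_list_based_on_platform())
def test_eagle_smoke(attn_backend: str, monkeypatch: pytest.MonkeyPatch):
    # Assumed wiring: select the backend through vLLM's environment variable
    # before constructing the engine/proposer under test.
    monkeypatch.setenv("VLLM_ATTENTION_BACKEND", attn_backend)
    ...  # build the proposer/engine and assert on its outputs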

vllm/v1/spec_decode/eagle.py (4 additions, 3 deletions)

@@ -72,12 +72,13 @@ def __init__(
 
         self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None
 
-        self.use_cuda_graph = (self.vllm_config.compilation_config.level
+        self.use_cuda_graph = (not current_platform.is_xpu()
+                               and self.vllm_config.compilation_config.level
                                == CompilationLevel.PIECEWISE and
                                not self.vllm_config.model_config.enforce_eager)
         self.cudagraph_batch_sizes = list(
-            reversed(
-                self.vllm_config.compilation_config.cudagraph_capture_sizes))
+            reversed(self.vllm_config.compilation_config.
+                     cudagraph_capture_sizes)) if self.use_cuda_graph else []
 
         # persistent buffers for cuda graph
         self.input_ids = torch.zeros(self.max_num_tokens,
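
In effect, CUDA graph capture is disabled up front on XPU, and the capture-size list is only materialized when graphs are actually in use, so the Eagle drafter falls back to eager execution in the XPU spec-decode UTs. A condensed sketch of the resulting gating logic follows; the diff's expressions are extracted into free functions purely for illustration, and the import paths are assumptions about vLLM's layout rather than part of this change.

from vllm.config import CompilationLevel
from vllm.platforms import current_platform


def should_use_cuda_graph(vllm_config) -> bool:
    # XPU has no CUDA graph support, so the proposer must never try to
    # capture graphs there, whatever the compilation level says.
    return (not current_platform.is_xpu()
            and vllm_config.compilation_config.level == CompilationLevel.PIECEWISE
            and not vllm_config.model_config.enforce_eager)


def cudagraph_batch_sizes(vllm_config, use_cuda_graph: bool) -> list[int]:
    # The capture sizes were previously always built; after this change they
    # are only populated when CUDA graphs will actually be used.
    if not use_cuda_graph:
        return []
    return list(
        reversed(vllm_config.compilation_config.cudagraph_capture_sizes))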
