Skip to content

Commit 13dede3

Browse files
yyihuang authored and BoyuanFeng committed
minor: zero workspace buffer init for flashinfer trtllm-gen attn (vllm-project#22603)
Signed-off-by: Boyuan Feng <[email protected]>
1 parent 13757ba commit 13dede3

File tree

3 files changed

+4
-4
lines changed

3 files changed

+4
-4
lines changed

tests/kernels/attention/test_flashinfer_trtllm_attention.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def test_flashinfer_trtllm_decode_with_baseline(
113113
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
114114
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
115115

116-
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
116+
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
117117
wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
118118
workspace_buffer,
119119
kv_layout,
@@ -247,7 +247,7 @@ def test_flashinfer_trtllm_prefill_with_baseline(
247247
kv_indices = torch.tensor(kv_indices, dtype=torch.int32)
248248
kv_last_page_lens = torch.tensor(kv_last_page_lens, dtype=torch.int32)
249249

250-
workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8)
250+
workspace_buffer = torch.zeros(128 * 1024 * 1024, dtype=torch.int8)
251251
wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
252252
workspace_buffer, kv_layout)
253253
wrapper.plan(q_indptr,

vllm/attention/backends/flashinfer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def __init__(self, runner):
203203

204204
def _get_workspace_buffer(self):
205205
if self._workspace_buffer is None:
206-
self._workspace_buffer = torch.empty(
206+
self._workspace_buffer = torch.zeros(
207207
FLASHINFER_WORKSPACE_BUFFER_SIZE,
208208
dtype=torch.uint8,
209209
device=self.runner.device)

vllm/v1/attention/backends/flashinfer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -252,7 +252,7 @@ def __init__(self, kv_cache_spec: AttentionSpec, layer_names: list[str],
252252

253253
def _get_workspace_buffer(self):
254254
if self._workspace_buffer is None:
255-
self._workspace_buffer = torch.empty(
255+
self._workspace_buffer = torch.zeros(
256256
FLASHINFER_WORKSPACE_BUFFER_SIZE,
257257
dtype=torch.uint8,
258258
device=self.device)

0 commit comments

Comments (0)