Commit bc1713c

fix - fix multi bs core
1 parent: a9f456e

File tree

1 file changed: +5 -1 lines changed


rtp_llm/models_py/modules/factory/attention/cuda_mla_impl/flashinfer_mla_wrapper.py

Lines changed: 5 additions & 1 deletion
@@ -215,7 +215,11 @@ def prepare(self, attn_inputs: PyAttentionInputs, use_cuda_graph: bool = False):
 
     def prepare_cuda_graph(self, attn_inputs: PyAttentionInputs):
         self.fmha_impl.cuda_graph_kv_indices = torch.empty(
-            (self.bs * self.max_context_len // self.seq_size_per_block),
+            (
+                (self.max_context_len + self.seq_size_per_block - 1)
+                // self.seq_size_per_block
+            )
+            * self.bs,
             dtype=torch.int32,
             device="cuda",
         )
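
A note on the fix (an inference from the diff, not text from the commit): the old expression self.bs * self.max_context_len // self.seq_size_per_block floor-divides the pooled token count of the whole batch, so whenever max_context_len is not a multiple of seq_size_per_block the tensor holds fewer KV-cache block indices than the batch actually needs, and the shortfall grows with batch size, which matches the multi-batch-size crash named in the commit message. The new expression ceiling-divides per sequence first, then multiplies by self.bs. A minimal Python sketch with illustrative values (the names mirror the wrapper's attributes; the numbers are made up):

# Illustrative values only; names mirror the wrapper's attributes.
bs = 2                  # batch size
max_context_len = 10    # max tokens per sequence
seq_size_per_block = 8  # tokens held by one KV-cache block

# Old sizing: floor-divide the pooled token count of the whole batch.
old_size = bs * max_context_len // seq_size_per_block   # 20 // 8 == 2

# New sizing: ceiling-divide per sequence, then scale by batch size.
blocks_per_seq = (max_context_len + seq_size_per_block - 1) // seq_size_per_block
new_size = blocks_per_seq * bs                          # ceil(10/8) * 2 == 4

# Each sequence needs 2 blocks, so a batch of 2 needs 4 index slots;
# the old buffer had room for only 2 and would be overrun.
assert old_size == 2 and new_size == 4

Rounding up per sequence guarantees every sequence's block indices fit in whole blocks, regardless of how the other sequences' lengths happen to round.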
