Skip to content

Commit 8017d7d

Browse files
committed
Remove kernel warmup
Signed-off-by: Divakar Verma <[email protected]>
1 parent c219220 commit 8017d7d

File tree

1 file changed

+3
-25
lines changed

1 file changed

+3
-25
lines changed

vllm/v1/attention/backends/mla/common.py

Lines changed: 3 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -242,8 +242,9 @@ def is_rocm_aiter_fp8bmm_enabled() -> bool:
242242

243243

244244
if is_rocm_aiter_fp8bmm_enabled():
245-
from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( # noqa: E501
246-
batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as aiter_triton_fp8_bmm)
245+
from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import ( # noqa: E501 # isort: skip
246+
batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant
247+
as aiter_triton_fp8_bmm)
247248

248249
def dynamic_per_batched_tensor_quant(
249250
x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn):
@@ -1042,29 +1043,6 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
10421043
W_K, dtype=current_platform.fp8_dtype())
10431044
self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant(
10441045
W_V, dtype=current_platform.fp8_dtype())
1045-
logger.info_once(
1046-
"[Aiter Triton] compiling fp8 BMM for batch sizes 1 to 128 "
1047-
f"W_K shape = {list(self.W_K.shape)} and "
1048-
f"W_V shape = {list(self.W_V.shape)}")
1049-
for m in range(1, 129):
1050-
x = torch.empty((self.W_K.shape[0], m, self.W_K.shape[2]),
1051-
dtype=torch.bfloat16,
1052-
device=self.W_K.device)
1053-
aiter_triton_fp8_bmm(x,
1054-
self.W_K,
1055-
self.W_K_scale,
1056-
group_size=128,
1057-
transpose_bm=True)
1058-
1059-
x = torch.empty((self.W_V.shape[0], m, self.W_V.shape[2]),
1060-
dtype=torch.bfloat16,
1061-
device=self.W_V.device)
1062-
aiter_triton_fp8_bmm(x,
1063-
self.W_V,
1064-
self.W_V_scale,
1065-
group_size=128,
1066-
transpose_bm=True)
1067-
10681046
else:
10691047
# Convert from (L, N, V) to (N, L, V)
10701048
self.W_UV = W_UV.transpose(0, 1)

0 commit comments

Comments (0)