@@ -242,8 +242,9 @@ def is_rocm_aiter_fp8bmm_enabled() -> bool:


if is_rocm_aiter_fp8bmm_enabled():
-    from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import (  # noqa: E501
-        batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant as aiter_triton_fp8_bmm)
+    from aiter.ops.triton.batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant import (  # noqa: E501 # isort: skip
+        batched_gemm_a8w8_a_per_token_group_prequant_w_per_batched_tensor_quant
+        as aiter_triton_fp8_bmm)

def dynamic_per_batched_tensor_quant(
        x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn):
@@ -1042,29 +1043,6 @@ def get_and_maybe_dequant_weights(layer: LinearBase):
                W_K, dtype=current_platform.fp8_dtype())
            self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant(
                W_V, dtype=current_platform.fp8_dtype())
-            logger.info_once(
-                "[Aiter Triton] compiling fp8 BMM for batch sizes 1 to 128 "
-                f"W_K shape = {list(self.W_K.shape)} and "
-                f"W_V shape = {list(self.W_V.shape)}")
-            for m in range(1, 129):
-                x = torch.empty((self.W_K.shape[0], m, self.W_K.shape[2]),
-                                dtype=torch.bfloat16,
-                                device=self.W_K.device)
-                aiter_triton_fp8_bmm(x,
-                                     self.W_K,
-                                     self.W_K_scale,
-                                     group_size=128,
-                                     transpose_bm=True)
-
-                x = torch.empty((self.W_V.shape[0], m, self.W_V.shape[2]),
-                                dtype=torch.bfloat16,
-                                device=self.W_V.device)
-                aiter_triton_fp8_bmm(x,
-                                     self.W_V,
-                                     self.W_V_scale,
-                                     group_size=128,
-                                     transpose_bm=True)
-
        else:
            # Convert from (L, N, V) to (N, L, V)
            self.W_UV = W_UV.transpose(0, 1)
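For reference, below is a minimal sketch of the call pattern that the removed warm-up exercised. It assumes `dynamic_per_batched_tensor_quant` and `aiter_triton_fp8_bmm` from the hunks above are in scope; the sizes are placeholders, not the actual W_K / W_V shapes.

```python
# Illustrative only: mirrors the removed warm-up's call pattern.
# Assumes dynamic_per_batched_tensor_quant and aiter_triton_fp8_bmm
# (see the import hunk above) are importable in this scope.
import torch

B, M, K, N = 16, 8, 512, 128  # made-up sizes, not real MLA weight shapes

# Quantize a bf16 weight to fp8 with one scale per batched tensor,
# as done for W_K and W_V in the second hunk.
w_bf16 = torch.randn((B, K, N), dtype=torch.bfloat16, device="cuda")
w_fp8, w_scale = dynamic_per_batched_tensor_quant(
    w_bf16, dtype=torch.float8_e4m3fn)

# bf16 activation whose last dim matches the weight's, built the same way
# the removed loop built it with torch.empty.
x = torch.empty((w_fp8.shape[0], M, w_fp8.shape[2]),
                dtype=torch.bfloat16,
                device=w_fp8.device)

# Same arguments as the removed warm-up call, which ran this once per
# m in 1..128 to pre-compile the Triton kernel for each shape.
aiter_triton_fp8_bmm(x, w_fp8, w_scale, group_size=128, transpose_bm=True)
```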