diff --git a/tritonbench/operators/decoding_attention/operator.py b/tritonbench/operators/decoding_attention/operator.py
index 680b0d5df..7d6a7adf7 100644
--- a/tritonbench/operators/decoding_attention/operator.py
+++ b/tritonbench/operators/decoding_attention/operator.py
@@ -469,11 +469,10 @@ def triton_splitk_fp8kv(
 ) -> Callable:
     _, _, num_q_heads, _ = q.shape
     batch_size, max_sequence_length, _, _ = k_cache.shape
+    k_cache = k_cache.to(torch.uint8).view(torch.int32)
+    v_cache = v_cache.to(torch.uint8).view(torch.int32)
     _q, _k, _v, attn_bias = _pack_xformer_input(q, k_cache, v_cache, cache_seqlens)
-    _k = _k.to(torch.uint8).view(torch.int32)
-    _v = _v.to(torch.uint8).view(torch.int32)
-
 
     k_fp8_scales_shifts = torch.zeros(
         batch_size * max_sequence_length,
         dtype=torch.int32,
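
The diff moves the int32 reinterpretation of the FP8 KV caches to before the `_pack_xformer_input` call, so the cast is applied to the full `k_cache`/`v_cache` tensors rather than to the packed `_k`/`_v` outputs afterwards. As a minimal standalone sketch of what `.view(torch.int32)` does here (shapes and values are hypothetical, not taken from this PR), viewing a contiguous uint8 tensor as int32 bit-packs four consecutive bytes into each element along the last dimension:

```python
import torch

# Hypothetical KV-cache shape: [batch, seq_len, num_kv_heads, head_dim].
batch, seq_len, num_kv_heads, head_dim = 2, 16, 8, 128
k_cache = torch.randint(
    0, 256, (batch, seq_len, num_kv_heads, head_dim), dtype=torch.uint8
)

# Tensor.view(dtype) reinterprets the raw bytes without copying. For int32,
# the last dimension must be contiguous and divisible by 4, and it shrinks
# by that factor: each int32 element packs 4 consecutive uint8 (FP8) bytes.
k_packed = k_cache.view(torch.int32)
print(k_packed.shape)  # torch.Size([2, 16, 8, 32]) -- head_dim / 4
```

Presumably this lets the downstream split-K kernel consume the FP8 cache as 32-bit words, though the diff itself only shows where the reinterpretation happens, not how the kernel uses it.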