@@ -291,8 +291,7 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
     float* __restrict__ max_logits,   // [num_seqs, num_heads, max_num_partitions]
     scalar_t* __restrict__ out,       // [num_seqs, num_heads, max_num_partitions, head_size]
     OUTT* __restrict__ final_out,     // [num_seqs, num_heads, head_size]
-    int max_ctx_blocks, const float* k_scale, const float* v_scale,
-    const float* __restrict__ fp8_out_scale_ptr) {
+    int max_ctx_blocks, const float* k_scale, const float* v_scale) {
   // clang-format on
   constexpr int NWARPS = NUM_THREADS / WARP_SIZE;
   const auto warpid = threadIdx.x / WARP_SIZE;
@@ -806,8 +805,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
     float* __restrict__ max_logits,   // [num_seqs, num_heads, max_num_partitions]
     scalar_t* __restrict__ out,       // [num_seqs, num_heads, max_num_partitions, head_size]
     OUTT* __restrict__ final_out,     // [num_seqs, num_heads, head_size]
-    int max_ctx_blocks, const float* k_scale, const float* v_scale,
-    const float* __restrict__ fp8_out_scale_ptr) {
+    int max_ctx_blocks, const float* k_scale, const float* v_scale) {
   // clang-format on
   constexpr int NWARPS = NUM_THREADS / WARP_SIZE;
   const auto warpid = threadIdx.x / WARP_SIZE;
@@ -1249,8 +1247,6 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(

   // final write to tmp_out after vout accumulation
   if (warpid == 0) {
-    const float out_scale =
-        (fp8_out_scale_ptr != nullptr) ? 1.0f / (*fp8_out_scale_ptr) : 1.0f;
     _B16x4 vout[QHLOOP][VHELOOP];
     // iterate across heads
     for (int qh = 0; qh < QHLOOP; qh++) {
@@ -3019,8 +3015,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel(
     float* __restrict__ max_logits,   // [num_seqs, num_heads, max_num_partitions]
     scalar_t* __restrict__ out,       // [num_seqs, num_heads, max_num_partitions, head_size]
     OUTT* __restrict__ final_out,     // [num_seqs, num_heads, head_size]
-    int max_ctx_blocks, const float* k_scale, const float* v_scale,
-    const float* __restrict__ fp8_out_scale_ptr) {
+    int max_ctx_blocks, const float* k_scale, const float* v_scale) {
   UNREACHABLE_CODE
 }

@@ -3047,8 +3042,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
     float* __restrict__ max_logits,   // [num_seqs, num_heads, max_num_partitions]
     scalar_t* __restrict__ out,       // [num_seqs, num_heads, max_num_partitions, head_size]
     OUTT* __restrict__ final_out,     // [num_seqs, num_heads, head_size]
-    int max_ctx_blocks, const float* k_scale, const float* v_scale,
-    const float* __restrict__ fp8_out_scale_ptr) {
+    int max_ctx_blocks, const float* k_scale, const float* v_scale) {
   UNREACHABLE_CODE
 }

@@ -3079,7 +3073,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
       block_tables_ptr, context_lens_ptr, query_start_loc_ptr,                \
       max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride,    \
       kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr,     \
-      max_ctx_blocks, k_scale_ptr, v_scale_ptr, fp8_out_scale_ptr);
+      max_ctx_blocks, k_scale_ptr, v_scale_ptr);

 #define LAUNCH_CUSTOM_ATTENTION_MFMA4(GQA_RATIO)                              \
   paged_attention_ll4mi_QKV_mfma4_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE,  \
@@ -3090,7 +3084,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
       block_tables_ptr, context_lens_ptr, query_start_loc_ptr,                \
       max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride,    \
       kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr,     \
-      max_ctx_blocks, k_scale_ptr, v_scale_ptr, fp8_out_scale_ptr);
+      max_ctx_blocks, k_scale_ptr, v_scale_ptr);

 #define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS)                                   \
   paged_attention_ll4mi_reduce_kernel<T, OUTT, HEAD_SIZE, HEAD_SIZE,          \