
Commit 307d8bc

Removing redundant parameters from the MIs side and fixing Navi build (ROCm#559)
Signed-off-by: Gregory Shtrasberg <[email protected]>
1 parent: 1900335

1 file changed: 6 additions, 12 deletions

1 file changed

+6
-12
lines changed

csrc/rocm/attention.cu

Lines changed: 6 additions & 12 deletions
@@ -291,8 +291,7 @@ __launch_bounds__(NUM_THREADS, 5) void paged_attention_ll4mi_QKV_mfma16_kernel(
     float* __restrict__ max_logits,  // [num_seqs, num_heads, max_num_partitions]
     scalar_t* __restrict__ out,      // [num_seqs, num_heads, max_num_partitions, head_size]
     OUTT* __restrict__ final_out,    // [num_seqs, num_heads, head_size]
-    int max_ctx_blocks, const float* k_scale, const float* v_scale,
-    const float* __restrict__ fp8_out_scale_ptr) {
+    int max_ctx_blocks, const float* k_scale, const float* v_scale) {
   // clang-format on
   constexpr int NWARPS = NUM_THREADS / WARP_SIZE;
   const auto warpid = threadIdx.x / WARP_SIZE;
@@ -806,8 +805,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
     float* __restrict__ max_logits,  // [num_seqs, num_heads, max_num_partitions]
     scalar_t* __restrict__ out,      // [num_seqs, num_heads, max_num_partitions, head_size]
     OUTT* __restrict__ final_out,    // [num_seqs, num_heads, head_size]
-    int max_ctx_blocks, const float* k_scale, const float* v_scale,
-    const float* __restrict__ fp8_out_scale_ptr) {
+    int max_ctx_blocks, const float* k_scale, const float* v_scale) {
   // clang-format on
   constexpr int NWARPS = NUM_THREADS / WARP_SIZE;
   const auto warpid = threadIdx.x / WARP_SIZE;
@@ -1249,8 +1247,6 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
 
   // final write to tmp_out after vout accumulation
   if (warpid == 0) {
-    const float out_scale =
-        (fp8_out_scale_ptr != nullptr) ? 1.0f / (*fp8_out_scale_ptr) : 1.0f;
     _B16x4 vout[QHLOOP][VHELOOP];
     // iterate across heads
     for (int qh = 0; qh < QHLOOP; qh++) {
@@ -3019,8 +3015,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma16_kernel(
     float* __restrict__ max_logits,  // [num_seqs, num_heads, max_num_partitions]
     scalar_t* __restrict__ out,      // [num_seqs, num_heads, max_num_partitions, head_size]
     OUTT* __restrict__ final_out,    // [num_seqs, num_heads, head_size]
-    int max_ctx_blocks, const float* k_scale, const float* v_scale,
-    const float* __restrict__ fp8_out_scale_ptr) {
+    int max_ctx_blocks, const float* k_scale, const float* v_scale) {
   UNREACHABLE_CODE
 }
 
@@ -3047,8 +3042,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_mfma4_kernel(
     float* __restrict__ max_logits,  // [num_seqs, num_heads, max_num_partitions]
     scalar_t* __restrict__ out,      // [num_seqs, num_heads, max_num_partitions, head_size]
     OUTT* __restrict__ final_out,    // [num_seqs, num_heads, head_size]
-    int max_ctx_blocks, const float* k_scale, const float* v_scale,
-    const float* __restrict__ fp8_out_scale_ptr) {
+    int max_ctx_blocks, const float* k_scale, const float* v_scale) {
   UNREACHABLE_CODE
 }
 
@@ -3079,7 +3073,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
       block_tables_ptr, context_lens_ptr, query_start_loc_ptr,             \
       max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
       kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr,  \
-      max_ctx_blocks, k_scale_ptr, v_scale_ptr, fp8_out_scale_ptr);
+      max_ctx_blocks, k_scale_ptr, v_scale_ptr);
 
 #define LAUNCH_CUSTOM_ATTENTION_MFMA4(GQA_RATIO)                             \
   paged_attention_ll4mi_QKV_mfma4_kernel<T, KVT, KV_DTYPE, OUTT, BLOCK_SIZE, \
@@ -3090,7 +3084,7 @@ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
       block_tables_ptr, context_lens_ptr, query_start_loc_ptr,             \
       max_num_blocks_per_seq, alibi_slopes_ptr, q_stride, kv_block_stride, \
       kv_head_stride, exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr,  \
-      max_ctx_blocks, k_scale_ptr, v_scale_ptr, fp8_out_scale_ptr);
+      max_ctx_blocks, k_scale_ptr, v_scale_ptr);
 
 #define LAUNCH_CUSTOM_REDUCTION(NPAR_LOOPS)                                  \
   paged_attention_ll4mi_reduce_kernel<T, OUTT, HEAD_SIZE, HEAD_SIZE,         \
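
The parameter removed throughout is an optional FP8 output scale: when `fp8_out_scale_ptr` was non-null, the mfma4 kernel applied its reciprocal at the final write, as the lines deleted around 1252 show. Below is a minimal, self-contained sketch of that nullable device-pointer scaling pattern, reconstructed from the deleted lines; `scale_output_kernel` is a hypothetical name used only for illustration and is not part of attention.cu.

// Illustrative HIP sketch (hypothetical kernel, not from attention.cu):
// the nullable device-pointer output-scale pattern this commit removes
// from the ll4mi paged-attention kernels.
#include <hip/hip_runtime.h>

__global__ void scale_output_kernel(const float* __restrict__ in,
                                    float* __restrict__ out, int n,
                                    const float* __restrict__ fp8_out_scale_ptr) {
  // Mirrors the deleted logic: 1.0f / *fp8_out_scale_ptr when a scale is
  // provided, otherwise an unscaled write.
  const float out_scale =
      (fp8_out_scale_ptr != nullptr) ? 1.0f / (*fp8_out_scale_ptr) : 1.0f;
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = in[i] * out_scale;
  }
}

With the parameter gone, the LAUNCH_CUSTOM_ATTENTION_MFMA16/MFMA4 macros above no longer forward `fp8_out_scale_ptr`, so the kernel signatures and their call sites stay in sync, including the UNREACHABLE_CODE stubs compiled for Navi.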
