diff --git a/fast_llm/functional/triton/rotary.py b/fast_llm/functional/triton/rotary.py index 2cea4c6d0..76e547657 100644 --- a/fast_llm/functional/triton/rotary.py +++ b/fast_llm/functional/triton/rotary.py @@ -17,12 +17,14 @@ def triton_rotary_kernel( num_heads: tl_constexpr, rotary_block_size: tl_constexpr, head_block_size: tl_constexpr, + seq_len: tl_constexpr, backward: tl_constexpr, ): # TODO: Int64 ptr if needed? - pid_0 = tl.program_id(axis=0) - pid_1 = tl.program_id(axis=1) - pid_2 = tl.program_id(axis=2) + pid_0_1 = tl.program_id(axis=0) # Folded (batch * seq) index + pid_2 = tl.program_id(axis=1) # Head index + pid_0 = pid_0_1 // seq_len + pid_1 = pid_0_1 - pid_0 * seq_len offsets = tl.arange(0, rotary_block_size) head_offsets = pid_2 * head_block_size + tl.arange(0, head_block_size)[:, None] @@ -76,7 +78,8 @@ def triton_rotary_( if head_block_size > num_heads: head_block_size = triton.next_power_of_2(num_heads) - triton_rotary_kernel[(batch_size, seq_len, triton.cdiv(num_heads, head_block_size))]( + # Folded the large y dim into the x dim as gridDim.x is 32 bit while gridDim.y and gridDim.z are 16 bit registers + triton_rotary_kernel[(batch_size * seq_len, triton.cdiv(num_heads, head_block_size))]( input_, frequencies, input_.stride(0), @@ -86,6 +89,7 @@ def triton_rotary_( num_heads, rotary_block_size, head_block_size, + seq_len, backward, # noqa ) return input_