Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions fast_llm/functional/triton/rotary.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,14 @@ def triton_rotary_kernel(
num_heads: tl_constexpr,
rotary_block_size: tl_constexpr,
head_block_size: tl_constexpr,
seq_len: tl_constexpr,
backward: tl_constexpr,
):
# TODO: Int64 ptr if needed?
pid_0 = tl.program_id(axis=0)
pid_1 = tl.program_id(axis=1)
pid_2 = tl.program_id(axis=2)
pid_0_1 = tl.program_id(axis=0) # Folded (batch * seq) index
pid_2 = tl.program_id(axis=1) # Head index
pid_0 = pid_0_1 // seq_len
pid_1 = pid_0_1 - pid_0 * seq_len

offsets = tl.arange(0, rotary_block_size)
head_offsets = pid_2 * head_block_size + tl.arange(0, head_block_size)[:, None]
Expand Down Expand Up @@ -76,7 +78,8 @@ def triton_rotary_(
if head_block_size > num_heads:
head_block_size = triton.next_power_of_2(num_heads)

triton_rotary_kernel[(batch_size, seq_len, triton.cdiv(num_heads, head_block_size))](
# Folded the large y dim into the x dim as gridDim.x is 32 bit while gridDim.y and gridDim.z are 16 bit registers
triton_rotary_kernel[(batch_size * seq_len, triton.cdiv(num_heads, head_block_size))](
input_,
frequencies,
input_.stride(0),
Expand All @@ -86,6 +89,7 @@ def triton_rotary_(
num_heads,
rotary_block_size,
head_block_size,
seq_len,
backward, # noqa
)
return input_
Expand Down