59 changes: 59 additions & 0 deletions repro.py
@@ -0,0 +1,59 @@
import torch
import triton
from liger_kernel.transformers.rope import liger_rotary_pos_emb
from transformers.models.llama.modeling_llama import (
apply_rotary_pos_emb,
LlamaRotaryEmbedding,
)

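# Fixed Llama-style attention config: 32 query and 32 KV heads, bf16 activations;
# `hidden_dim` here is the per-head dimension fed to LlamaRotaryEmbedding.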
num_q_heads, num_kv_heads = 32, 32
hidden_dim = 128
dtype = torch.bfloat16


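# Build bf16 q/k in the (batch, num_heads, seq_len, head_dim) layout expected by
# apply_rotary_pos_emb and liger_rotary_pos_emb, plus the cos/sin caches.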
def prepare_input(batch_size, seq_length):
rotary_emb = LlamaRotaryEmbedding(hidden_dim, device="cuda")
q = torch.randn(
(batch_size, num_q_heads, seq_length, hidden_dim),
device="cuda",
requires_grad=True,
dtype=dtype,
)
k = torch.randn(
(batch_size, num_kv_heads, seq_length, hidden_dim),
device="cuda",
requires_grad=True,
dtype=dtype,
)
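    # dq and dk mirror the operator's prepare_input but are unused in this
    # forward-only repro.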
dq, dk = (
torch.randn_like(q, device="cuda", dtype=dtype),
torch.randn_like(k, device="cuda"),
)
pos_ids = torch.arange(seq_length, device="cuda", dtype=torch.long).unsqueeze(0)
cos, sin = rotary_emb(k, pos_ids)
return q, k, cos, sin, pos_ids


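# Close over the prepared inputs so only the Liger RoPE Triton kernel is timed.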
def liger_rotary_pos_emb_kernel(batch_size, seq_length):
q, k, cos, sin, pos_ids = prepare_input(batch_size, seq_length)
return lambda: liger_rotary_pos_emb(q, k, cos, sin, pos_ids)


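# Compile only apply_rotary_pos_emb; cos/sin are computed eagerly outside the
# returned lambda, so just the rotation itself is benchmarked.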
def inductor_rotary_pos_emb_full_op(batch_size, seq_length):
q, k, cos, sin, pos_ids = prepare_input(batch_size, seq_length)
get_rotary_embedding = LlamaRotaryEmbedding(hidden_dim, device="cuda")
cos, sin = get_rotary_embedding(k, pos_ids)
compiled_func = torch.compile(
apply_rotary_pos_emb, mode="max-autotune-no-cudagraphs"
)
return lambda: compiled_func(q, k, cos, sin, pos_ids)


compiled_fn = inductor_rotary_pos_emb_full_op(2, 1)
liger_kernel = liger_rotary_pos_emb_kernel(2, 1)


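# Benchmark both callables under CUDA graphs; do_bench_cudagraph reports latency in ms.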
compiler_latency = triton.testing.do_bench_cudagraph(compiled_fn)
liger_latency = triton.testing.do_bench_cudagraph(liger_kernel)

print("compiler_latency:", compiler_latency, "liger_latency:", liger_latency)
71 changes: 32 additions & 39 deletions tritonbench/operators/rope/operator.py
@@ -31,39 +31,31 @@ def __init__(
         self.baseline_op = None
         self.liger_op = None
         self.num_q_heads = 32
-        self.num_kv_heads = 8
+        self.num_kv_heads = 32  # should be 8
+        self.hidden_dim = 128
 
     def get_input_iter(self) -> Generator:
-        hidden_size = 8192
-        for seq_length in [2**i for i in range(10, 15)]:
-            yield hidden_size, seq_length
+        batch_size = 1
+        for seq_length in [2**i for i in range(0, 19)]:
+            yield batch_size, seq_length
 
-        seq_length = 2048
-        for hidden_size in [32 * (2**i) for i in range(4, 10, 2)]:
-            yield hidden_size, seq_length
+        seq_length = 1
+        for batch_size in [2**i for i in range(0, 11)]:
+            yield batch_size, seq_length
 
-    def prepare_input(self, hidden_size, seq_length):
-        head_dim = hidden_size // self.num_q_heads
-        rotary_emb = LlamaRotaryEmbedding(head_dim, device=self.device)
-        q = (
-            torch.randn(
-                (1, seq_length, self.num_q_heads, head_dim),
-                device=self.device,
-                requires_grad=True,
-                dtype=self.dtype,
-            )
-            .transpose(1, 2)
-            .contiguous()
+    def prepare_input(self, batch_size, seq_length):
+        rotary_emb = LlamaRotaryEmbedding(self.hidden_dim, device=self.device)
+        q = torch.randn(
+            (batch_size, self.num_q_heads, seq_length, self.hidden_dim),
+            device=self.device,
+            requires_grad=True,
+            dtype=self.dtype,
         )
-        k = (
-            torch.randn(
-                (1, seq_length, self.num_kv_heads, head_dim),
-                device=self.device,
-                requires_grad=True,
-                dtype=self.dtype,
-            )
-            .transpose(1, 2)
-            .contiguous()
+        k = torch.randn(
+            (batch_size, self.num_kv_heads, seq_length, self.hidden_dim),
+            device=self.device,
+            requires_grad=True,
+            dtype=self.dtype,
         )
         dq, dk = (
             torch.randn_like(q, device=self.device, dtype=self.dtype),
@@ -82,25 +74,26 @@ def prepare_input(self, hidden_size, seq_length):
         return q, k, cos, sin, pos_ids
 
     @register_benchmark(baseline=True)
-    def apply_rotary_pos_emb(self, hidden_size, seq_length) -> Callable:
-        q, k, cos, sin, pos_ids = self.prepare_input(hidden_size, seq_length)
+    def apply_rotary_pos_emb(self, batch_size, seq_length) -> Callable:
+        q, k, cos, sin, pos_ids = self.prepare_input(batch_size, seq_length)
         return lambda: apply_rotary_pos_emb(q, k, cos, sin, pos_ids)
 
     @register_benchmark()
-    def liger_rotary_pos_emb(self, hidden_size, seq_length) -> Callable:
-        q, k, cos, sin, pos_ids = self.prepare_input(hidden_size, seq_length)
+    def liger_rotary_pos_emb(self, batch_size, seq_length) -> Callable:
+        q, k, cos, sin, pos_ids = self.prepare_input(batch_size, seq_length)
         return lambda: liger_rotary_pos_emb(q, k, cos, sin, pos_ids)
 
     @register_benchmark()
-    def inductor_rotary_pos_emb_full_op(self, hidden_size, seq_length) -> Callable:
-        q, k, cos, sin, pos_ids = self.prepare_input(hidden_size, seq_length)
-        head_dim = hidden_size // self.num_q_heads
-        compiled = torch.compile(LlamaRotaryEmbedding(head_dim, device=self.device))
-        cos, sin = compiled(k, pos_ids)
-        compiled_func = torch.compile(apply_rotary_pos_emb)
+    def inductor_rotary_pos_emb_full_op(self, batch_size, seq_length) -> Callable:
+        q, k, cos, sin, pos_ids = self.prepare_input(batch_size, seq_length)
+        get_rotary_embedding = LlamaRotaryEmbedding(self.hidden_dim, device=self.device)
+        cos, sin = get_rotary_embedding(k, pos_ids)
+        compiled_func = torch.compile(
+            apply_rotary_pos_emb, mode="max-autotune-no-cudagraphs"
+        )
         return lambda: compiled_func(q, k, cos, sin, pos_ids)
 
-    @register_x_val(label="(H, T)")
+    @register_x_val(label="(B, S)")
     def get_x_val(self, example_inputs) -> Tuple[int, int]:
         return (example_inputs[0], example_inputs[1])
 