diff --git a/repro.py b/repro.py
new file mode 100644
index 000000000..cc4e4d6c4
--- /dev/null
+++ b/repro.py
@@ -0,0 +1,59 @@
+import torch
+import triton
+from liger_kernel.transformers.rope import liger_rotary_pos_emb
+from transformers.models.llama.modeling_llama import (
+    apply_rotary_pos_emb,
+    LlamaRotaryEmbedding,
+)
+
+num_q_heads, num_kv_heads = 32, 32
+hidden_dim = 128
+dtype = torch.bfloat16
+
+
+def prepare_input(batch_size, seq_length):
+    rotary_emb = LlamaRotaryEmbedding(hidden_dim, device="cuda")
+    q = torch.randn(
+        (batch_size, num_q_heads, seq_length, hidden_dim),
+        device="cuda",
+        requires_grad=True,
+        dtype=dtype,
+    )
+    k = torch.randn(
+        (batch_size, num_kv_heads, seq_length, hidden_dim),
+        device="cuda",
+        requires_grad=True,
+        dtype=dtype,
+    )
+    dq, dk = (
+        torch.randn_like(q, device="cuda", dtype=dtype),
+        torch.randn_like(k, device="cuda"),
+    )
+    pos_ids = torch.arange(seq_length, device="cuda", dtype=torch.long).unsqueeze(0)
+    cos, sin = rotary_emb(k, pos_ids)
+    return q, k, cos, sin, pos_ids
+
+
+def liger_rotary_pos_emb_kernel(batch_size, seq_length):
+    q, k, cos, sin, pos_ids = prepare_input(batch_size, seq_length)
+    return lambda: liger_rotary_pos_emb(q, k, cos, sin, pos_ids)
+
+
+def inductor_rotary_pos_emb_full_op(batch_size, seq_length):
+    q, k, cos, sin, pos_ids = prepare_input(batch_size, seq_length)
+    get_rotary_embedding = LlamaRotaryEmbedding(hidden_dim, device="cuda")
+    cos, sin = get_rotary_embedding(k, pos_ids)
+    compiled_func = torch.compile(
+        apply_rotary_pos_emb, mode="max-autotune-no-cudagraphs"
+    )
+    return lambda: compiled_func(q, k, cos, sin, pos_ids)
+
+
+compiled_fn = inductor_rotary_pos_emb_full_op(2, 1)
+liger_kernel = liger_rotary_pos_emb_kernel(2, 1)
+
+
+compiler_latency = triton.testing.do_bench_cudagraph(compiled_fn)
+liger_latency = triton.testing.do_bench_cudagraph(liger_kernel)
+
+print("compiler_latency:", compiler_latency, "liger_latency:", liger_latency)
diff --git a/tritonbench/operators/rope/operator.py b/tritonbench/operators/rope/operator.py
index e7a0df73f..9a86f6a7e 100644
--- a/tritonbench/operators/rope/operator.py
+++ b/tritonbench/operators/rope/operator.py
@@ -31,39 +31,31 @@ def __init__(
         self.baseline_op = None
         self.liger_op = None
         self.num_q_heads = 32
-        self.num_kv_heads = 8
+        self.num_kv_heads = 32  # should be 8
+        self.hidden_dim = 128
 
     def get_input_iter(self) -> Generator:
-        hidden_size = 8192
-        for seq_length in [2**i for i in range(10, 15)]:
-            yield hidden_size, seq_length
+        batch_size = 1
+        for seq_length in [2**i for i in range(0, 19)]:
+            yield batch_size, seq_length
 
-        seq_length = 2048
-        for hidden_size in [32 * (2**i) for i in range(4, 10, 2)]:
-            yield hidden_size, seq_length
+        seq_length = 1
+        for batch_size in [2**i for i in range(0, 11)]:
+            yield batch_size, seq_length
 
-    def prepare_input(self, hidden_size, seq_length):
-        head_dim = hidden_size // self.num_q_heads
-        rotary_emb = LlamaRotaryEmbedding(head_dim, device=self.device)
-        q = (
-            torch.randn(
-                (1, seq_length, self.num_q_heads, head_dim),
-                device=self.device,
-                requires_grad=True,
-                dtype=self.dtype,
-            )
-            .transpose(1, 2)
-            .contiguous()
+    def prepare_input(self, batch_size, seq_length):
+        rotary_emb = LlamaRotaryEmbedding(self.hidden_dim, device=self.device)
+        q = torch.randn(
+            (batch_size, self.num_q_heads, seq_length, self.hidden_dim),
+            device=self.device,
+            requires_grad=True,
+            dtype=self.dtype,
         )
-        k = (
-            torch.randn(
-                (1, seq_length, self.num_kv_heads, head_dim),
-                device=self.device,
-                requires_grad=True,
-                dtype=self.dtype,
-            )
-            .transpose(1, 2)
-            .contiguous()
+        k = torch.randn(
+            (batch_size, self.num_kv_heads, seq_length, self.hidden_dim),
+            device=self.device,
+            requires_grad=True,
+            dtype=self.dtype,
         )
         dq, dk = (
             torch.randn_like(q, device=self.device, dtype=self.dtype),
@@ -82,25 +74,26 @@ def prepare_input(self, hidden_size, seq_length):
         return q, k, cos, sin, pos_ids
 
     @register_benchmark(baseline=True)
-    def apply_rotary_pos_emb(self, hidden_size, seq_length) -> Callable:
-        q, k, cos, sin, pos_ids = self.prepare_input(hidden_size, seq_length)
+    def apply_rotary_pos_emb(self, batch_size, seq_length) -> Callable:
+        q, k, cos, sin, pos_ids = self.prepare_input(batch_size, seq_length)
         return lambda: apply_rotary_pos_emb(q, k, cos, sin, pos_ids)
 
     @register_benchmark()
-    def liger_rotary_pos_emb(self, hidden_size, seq_length) -> Callable:
-        q, k, cos, sin, pos_ids = self.prepare_input(hidden_size, seq_length)
+    def liger_rotary_pos_emb(self, batch_size, seq_length) -> Callable:
+        q, k, cos, sin, pos_ids = self.prepare_input(batch_size, seq_length)
         return lambda: liger_rotary_pos_emb(q, k, cos, sin, pos_ids)
 
     @register_benchmark()
-    def inductor_rotary_pos_emb_full_op(self, hidden_size, seq_length) -> Callable:
-        q, k, cos, sin, pos_ids = self.prepare_input(hidden_size, seq_length)
-        head_dim = hidden_size // self.num_q_heads
-        compiled = torch.compile(LlamaRotaryEmbedding(head_dim, device=self.device))
-        cos, sin = compiled(k, pos_ids)
-        compiled_func = torch.compile(apply_rotary_pos_emb)
+    def inductor_rotary_pos_emb_full_op(self, batch_size, seq_length) -> Callable:
+        q, k, cos, sin, pos_ids = self.prepare_input(batch_size, seq_length)
+        get_rotary_embedding = LlamaRotaryEmbedding(self.hidden_dim, device=self.device)
+        cos, sin = get_rotary_embedding(k, pos_ids)
+        compiled_func = torch.compile(
+            apply_rotary_pos_emb, mode="max-autotune-no-cudagraphs"
+        )
        return lambda: compiled_func(q, k, cos, sin, pos_ids)
 
-    @register_x_val(label="(H, T)")
+    @register_x_val(label="(B, S)")
     def get_x_val(self, example_inputs) -> Tuple[int, int]:
         return (example_inputs[0], example_inputs[1])
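
For context, a minimal driver sketch (not part of the patch) showing how the repro helpers could be swept over the same batch sizes the updated get_input_iter yields; it assumes repro.py is importable from the working directory, and note that importing it also runs its module-level (2, 1) comparison once.

# Hypothetical sweep driver; mirrors get_input_iter's batch sweep with seq_length = 1.
# Assumes a CUDA device plus the same dependencies as repro.py (triton, liger-kernel,
# transformers). Importing repro executes its module-level (2, 1) benchmark first.
import triton

from repro import inductor_rotary_pos_emb_full_op, liger_rotary_pos_emb_kernel

seq_length = 1
for batch_size in [2**i for i in range(0, 11)]:
    compiled_fn = inductor_rotary_pos_emb_full_op(batch_size, seq_length)
    liger_fn = liger_rotary_pos_emb_kernel(batch_size, seq_length)
    print(
        "batch_size:", batch_size,
        "compiler_latency:", triton.testing.do_bench_cudagraph(compiled_fn),
        "liger_latency:", triton.testing.do_bench_cudagraph(liger_fn),
    )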