
Commit 0487870

rm tuning
1 parent 054263a commit 0487870

File tree

1 file changed: +1 -17 lines changed

lightllm/models/llama/triton_kernel/rmsnorm.py

Lines changed: 1 addition & 17 deletions
@@ -2,7 +2,6 @@
 
 import triton
 import triton.language as tl
-from lightllm.common.triton_utils.autotuner import autotune
 
 
 @triton.jit
@@ -53,18 +52,7 @@ def get_test_configs():
     ]
 
 
-def get_static_key(x):
-    return {"N": x.shape[-1], "out_dtype": str(x.dtype)}
-
-
-@autotune(
-    kernel_name="rms_norm_fwd_fused:v1",
-    configs_gen_func=get_test_configs,
-    static_key_func=get_static_key,
-    run_key_func=lambda x: x.shape[0],
-    mutates_args=[],
-)
-def rmsnorm_forward(x: torch.Tensor, weight, eps, out=None, run_config=None):
+def rmsnorm_forward(x: torch.Tensor, weight, eps, out=None):
     # allocate output
     y = torch.empty_like(x) if out is None else out
     # reshape input data into 2D tensor
@@ -84,10 +72,6 @@ def rmsnorm_forward(x: torch.Tensor, weight, eps, out=None, run_config=None):
     if BLOCK_SIZE > 16384:
         BLOCK_SIZE = 16384
 
-    if run_config is not None:
-        BLOCK_SIZE = run_config["BLOCK_SIZE"]
-        num_warps = run_config["num_warps"]
-
     # enqueue kernel
     _rms_norm_fwd_fused[(M,)](
         x_arg,
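
For reference, a minimal usage sketch of the simplified entry point after this commit. The shapes, dtype, and eps value below are illustrative assumptions, not taken from the repository.

import torch
from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward

# Illustrative input: an (M, N) activation tensor and a per-channel weight on GPU.
x = torch.randn(8, 4096, dtype=torch.float16, device="cuda")
weight = torch.ones(4096, dtype=torch.float16, device="cuda")

# No run_config is passed anymore; BLOCK_SIZE and num_warps are derived
# inside rmsnorm_forward from the hidden dimension alone.
y = rmsnorm_forward(x, weight, eps=1e-6)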
