@@ -14,24 +14,32 @@ def init_func(nargs):
 # NOTE: will need to warm up kernels each time, triton autotune caching isn't a thing right now
 
 configs = [
+    triton.Config({"BLOCK_M": 64, "BLOCK_N": 128}, num_warps=2, pre_hook=init_to_zero("Y")),
+
+    triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, pre_hook=init_to_zero("Y")),
     triton.Config({"BLOCK_M": 8, "BLOCK_N": 128}, num_warps=2, pre_hook=init_to_zero("Y")),
     triton.Config({"BLOCK_M": 16, "BLOCK_N": 256}, num_warps=4, pre_hook=init_to_zero("Y")),
     triton.Config({"BLOCK_M": 16, "BLOCK_N": 256}, num_warps=4, pre_hook=init_to_zero("Y")),
-    triton.Config({"BLOCK_M": 16, "BLOCK_N": 512}, num_warps=4, pre_hook=init_to_zero("Y")),
-    #triton.Config({"BLOCK_M": 16, "BLOCK_N": 1024}, num_warps=4, pre_hook=init_to_zero("Y")),
     triton.Config({"BLOCK_M": 32, "BLOCK_N": 256}, num_warps=4, pre_hook=init_to_zero("Y")),
-    triton.Config({"BLOCK_M": 32, "BLOCK_N": 512}, num_warps=4, pre_hook=init_to_zero("Y")),
-    #triton.Config({"BLOCK_M": 32, "BLOCK_N": 1024}, num_warps=4, pre_hook=init_to_zero("Y")),
     triton.Config({"BLOCK_M": 64, "BLOCK_N": 256}, num_warps=4, pre_hook=init_to_zero("Y")),
-    triton.Config({"BLOCK_M": 64, "BLOCK_N": 512}, num_warps=4, pre_hook=init_to_zero("Y")),
-    #triton.Config({"BLOCK_M": 64, "BLOCK_N": 1024}, num_warps=4, pre_hook=init_to_zero("Y")),
     triton.Config({"BLOCK_M": 128, "BLOCK_N": 16}, num_warps=4, pre_hook=init_to_zero("Y")),
     triton.Config({"BLOCK_M": 128, "BLOCK_N": 32}, num_warps=4, pre_hook=init_to_zero("Y")),
     triton.Config({"BLOCK_M": 128, "BLOCK_N": 64}, num_warps=4, pre_hook=init_to_zero("Y")),
     triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, pre_hook=init_to_zero("Y")),
     triton.Config({"BLOCK_M": 128, "BLOCK_N": 256}, num_warps=4, pre_hook=init_to_zero("Y")),
-    # triton.Config({"BLOCK_M": 128, "BLOCK_N": 512}, num_warps=4, pre_hook=init_to_zero("Y")),
-    #triton.Config({"BLOCK_M": 128, "BLOCK_N": 1024}, num_warps=4, pre_hook=init_to_zero("Y")),
+
+    triton.Config({"BLOCK_M": 128, "BLOCK_N": 512}, num_warps=4, pre_hook=init_to_zero("Y")),
+    triton.Config({"BLOCK_M": 64, "BLOCK_N": 512}, num_warps=4, pre_hook=init_to_zero("Y")),
+    triton.Config({"BLOCK_M": 32, "BLOCK_N": 512}, num_warps=4, pre_hook=init_to_zero("Y")),
+    triton.Config({"BLOCK_M": 16, "BLOCK_N": 512}, num_warps=4, pre_hook=init_to_zero("Y")),
+
+
+    # Llama 3 variants can use BLOCK_N >= 1024
+    # triton.Config({"BLOCK_M": 128, "BLOCK_N": 1024}, num_warps=4, pre_hook=init_to_zero("Y")),
+    # triton.Config({"BLOCK_M": 16, "BLOCK_N": 1024}, num_warps=4, pre_hook=init_to_zero("Y")),
+    # triton.Config({"BLOCK_M": 64, "BLOCK_N": 1024}, num_warps=4, pre_hook=init_to_zero("Y")),
+    # triton.Config({"BLOCK_M": 32, "BLOCK_N": 1024}, num_warps=4, pre_hook=init_to_zero("Y")),
+    # triton.Config({"BLOCK_M": 16, "BLOCK_N": 1024}, num_warps=4, pre_hook=init_to_zero("Y")),
 ]
 
 @triton.autotune(
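
Because autotune results are not persisted across processes (see the NOTE above), every fresh process re-benchmarks the whole config list the first time the kernel runs. A minimal warm-up sketch follows; it assumes `qkv_gemv` from this module is in scope, and the shapes, dtype, and threshold values are placeholders rather than values taken from the repository.

import torch

# Warm-up sketch (illustrative assumptions): call the autotuned path once per
# process, outside the latency-critical loop, so Triton benchmarks every entry
# in `configs` before real decoding starts.
def warmup_qkv_gemv(hidden=4096, kv_size=1024, device="cuda", dtype=torch.float16):
    # x.shape[1] == 1 selects the gemv path in forward(); all dims are assumed
    x = torch.randn(1, 1, hidden, device=device, dtype=dtype)
    weight = torch.randn(hidden + 2 * kv_size, hidden, device=device, dtype=dtype)
    # zero thresholds and sparsity_bin = 0 are arbitrary warm-up arguments
    qkv_gemv(x, weight, 0.0, 0.0, 0.0, 0, kv_size)
    torch.cuda.synchronize()  # make sure the tuning runs have finished
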
@@ -287,7 +295,6 @@ def forward(
     sparsity_bin: int,
     kv_size: int
 ) -> torch.Tensor:
-    return torch.matmul(x, weight.T)
     return qkv_gemv(x, weight, threshold_q, threshold_k, threshold_v, sparsity_bin, kv_size) if x.shape[1] == 1 else torch.matmul(x, weight.T)
 
 # for testing purposes, to see if overhead at 0% is really due to strengthening torch.matmul (seems like it is)
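
The surviving return line dispatches on sequence length: a single generated token (x.shape[1] == 1) goes through the sparse Triton gemv, while prefill falls back to a dense matmul. A small illustration of that split, with assumed shapes and dtype (only the dispatch condition and argument order come from the diff):

import torch

# Illustrative only: batch=1; hidden size and fused QKV output dim are assumptions.
hidden = 4096
weight = torch.randn(3 * hidden, hidden, device="cuda", dtype=torch.float16)

x_prefill = torch.randn(1, 128, hidden, device="cuda", dtype=torch.float16)  # shape[1] != 1
x_decode = torch.randn(1, 1, hidden, device="cuda", dtype=torch.float16)     # shape[1] == 1

y_prefill = torch.matmul(x_prefill, weight.T)  # prefill: dense path (the else branch)
# a decode step would instead route to the Triton kernel, e.g.:
# y_decode = qkv_gemv(x_decode, weight, threshold_q, threshold_k, threshold_v, sparsity_bin, kv_size)
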