
  # static_shapes=True gives a performance boost for matmuls
  @helion.kernel(static_shapes=True)
- def matmul_split_k(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+ def matmul_split_k_no_bias(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
      m, k = x.size()
      k2, n = y.size()
      assert k == k2, f"size mismatch {k} != {k2}"
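A note on the comment above: as I understand Helion's `static_shapes=True`, it specializes the compiled kernel on the exact input sizes, turning shape-dependent index math into compile-time constants; the cost is one recompilation per distinct shape. An illustrative sketch of that behavior at the call site (sizes are made up):

    # Hypothetical call sites; under static_shapes=True each distinct
    # (m, k, n) triple would get its own cached specialization.
    a = torch.randn([512, 256], device="cuda", dtype=torch.float16)
    b = torch.randn([256, 128], device="cuda", dtype=torch.float16)
    matmul_split_k_no_bias(a, b)  # compiles a 512x256x128 variant
    c = torch.randn([1024, 256], device="cuda", dtype=torch.float16)
    matmul_split_k_no_bias(c, b)  # new shape -> new specialization
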
@@ -27,10 +27,46 @@ def matmul_split_k(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
      return out


+ @helion.kernel(static_shapes=True)
+ def matmul_split_k_with_bias(x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor) -> torch.Tensor:
+     m, k = x.size()
+     k2, n = y.size()
+     assert k == k2, f"size mismatch {k} != {k2}"
+     bias_size = bias.size(0)
+     assert bias_size == n, f"bias size mismatch, expected {n}, got {bias_size}"
+
+     # Initialize the output with the bias (instead of zeros) so the bias is
+     # added exactly once before the split-k partial products accumulate.
+     out = bias.expand(m, n).contiguous()
+
+     split_k = hl.register_tunable("split_k", PowerOfTwoFragment(1, 256))
+     k_block = helion.next_power_of_2(helion.cdiv(k, split_k))
+     for tile_m, tile_n, outer_k in hl.tile([m, n, k], block_size=[None, None, k_block]):
+         acc = hl.zeros([tile_m, tile_n], dtype=torch.float32)
+         for inner_k in hl.tile(outer_k.begin, outer_k.end):
+             acc = torch.addmm(acc, x[tile_m, inner_k], y[inner_k, tile_n])
+         hl.atomic_add(out, [tile_m, tile_n], acc)
+     return out
+
+
+ def matmul_split_k(x: torch.Tensor, y: torch.Tensor, bias: torch.Tensor | None = None) -> torch.Tensor:
+     """Wrapper function for tritonbench that dispatches based on bias presence."""
+     if bias is None:
+         return matmul_split_k_no_bias(x, y)
+     else:
+         return matmul_split_k_with_bias(x, y, bias)
+
+
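What the new kernel computes is easier to see outside the DSL: K is cut into `split_k` chunks (a power of two chosen by the autotuner via `PowerOfTwoFragment`), each (tile_m, tile_n, outer_k) program forms a float32 partial product for its chunk, and `hl.atomic_add` folds the partials into an output pre-seeded with the bias, so the bias lands exactly once. A minimal pure-PyTorch sketch of the same decomposition, for intuition only (chunking and dtype handling simplified; this is not the generated kernel):

    import torch

    def split_k_reference(x, y, bias, split_k=4):
        m, k = x.size()
        k_block = -(-k // split_k)  # ceil-divide K into split_k chunks
        out = bias.float().expand(m, y.size(1)).contiguous()  # seed with bias once
        # Accumulate one partial product per K-chunk, mirroring the
        # per-program hl.atomic_add in the kernel above.
        for xs, ys in zip(x.split(k_block, dim=1), y.split(k_block, dim=0)):
            out += xs.float() @ ys.float()
        return out.to(x.dtype)
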
  def check(m: int, k: int, n: int) -> None:
      x = torch.randn([m, k], device="cuda", dtype=torch.float16)
      y = torch.randn([k, n], device="cuda", dtype=torch.float16)
-     run_example(matmul_split_k, torch.matmul, (x, y), atol=1)
+
+     # Test without bias
+     run_example(matmul_split_k_no_bias, torch.matmul, (x, y), atol=1)
+
+     # Test with bias
+     bias = torch.randn([n], device="cuda", dtype=torch.float16)
+     expected_with_bias = lambda x, y, bias: torch.matmul(x, y) + bias
+     run_example(matmul_split_k_with_bias, expected_with_bias, (x, y, bias), atol=1)
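Two small observations on the tests: the wrapper keeps a single entry point for tritonbench, routing each call to the kernel specialized for its signature, and the loose `atol=1` presumably absorbs both float16 precision and the nondeterministic summation order of the atomic adds. Illustrative call sites (reusing the tensors from `check`):

    out_no_bias = matmul_split_k(x, y)          # bias is None -> matmul_split_k_no_bias
    out_with_bias = matmul_split_k(x, y, bias)  # bias present -> matmul_split_k_with_bias
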


  def main() -> None: