@@ -1,6 +1,6 @@
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Optional
+from typing import Optional, Tuple

 try:
     import iree.turbine.kernel as tk
@@ -24,6 +24,8 @@
 from iree.compiler import ir
 from iree.compiler.dialects import arith, func, linalg, tensor

+kDynamic = ir.ShapedType.get_dynamic_size()
+

 def num_bytes(dtype: str) -> int:
     dtype_to_bytes = {
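For context on the sentinel introduced above: ir.ShapedType.get_dynamic_size() returns the integer value MLIR uses to mark a dynamic extent in a shaped type. A minimal standalone sketch, not part of the patch, assuming the upstream MLIR Python bindings that iree.compiler re-exports:

from iree.compiler import ir

kDynamic = ir.ShapedType.get_dynamic_size()

with ir.Context():
    f32 = ir.F32Type.get()
    # Equivalent to tensor<?x4xf32>: dim 0 is dynamic, dim 1 is static.
    t = ir.RankedTensorType.get([kDynamic, 4], f32)
    assert t.is_dynamic_dim(0)
    assert t.get_dim_size(1) == 4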
@@ -42,6 +44,7 @@ def num_bytes(dtype: str) -> int:

 @dataclass
 class GemmConfig:
+    # Note that M, N and K may be set to kDynamic, a sentinel marking that dim as dynamic.
     M: int
     N: int
     K: int
@@ -50,37 +53,62 @@ class GemmConfig:
     operand_element_type: str
     accumulator_element_type: str
     result_element_type: str
+    # runtime_dim substitutes for any dynamic dims when executing.
+    # TODO: It would be better if we could execute the same compiled dynamic
+    # kernel for a series of different sizes, rather than duplicating the
+    # GemmConfig. The current design's advantage is that no changes have
+    # to be made to the execution logic (looks just like a static shape).
+    runtime_dim: Optional[int] = None

     def get_name(self) -> str:
-        name = f"gemm_{self.M}_{self.N}_{self.K}_{self.operand_element_type}_{self.accumulator_element_type}"
+        M = self.M if self.M != kDynamic else "D"
+        N = self.N if self.N != kDynamic else "D"
+        K = self.K if self.K != kDynamic else "D"
+        name = f"gemm_{M}_{N}_{K}_{self.operand_element_type}_{self.accumulator_element_type}"
         if self.tA == "T":
             name += "_tA"
         elif self.tB == "T":
             name += "_tB"
+        if self.runtime_dim is not None:
+            name += f"_D={self.runtime_dim}"
         return name

+    def get_runtime_dims(self) -> Tuple[int, int, int]:
+        """
+        Get concrete dims to use when executing this kernel.
+        """
+        M = self.M if self.M != kDynamic else self.runtime_dim
+        N = self.N if self.N != kDynamic else self.runtime_dim
+        K = self.K if self.K != kDynamic else self.runtime_dim
+        return M, N, K
+
     def get_inp1(self) -> str:
+        M, N, K = self.get_runtime_dims()
         if self.tA == "T":
-            return f"{self.K}x{self.M}x{self.operand_element_type}"
-        return f"{self.M}x{self.K}x{self.operand_element_type}"
+            return f"{K}x{M}x{self.operand_element_type}"
+        return f"{M}x{K}x{self.operand_element_type}"

     def get_inp2(self) -> str:
+        M, N, K = self.get_runtime_dims()
         if self.tB == "T":
-            return f"{self.N}x{self.K}x{self.operand_element_type}"
-        return f"{self.K}x{self.N}x{self.operand_element_type}"
+            return f"{N}x{K}x{self.operand_element_type}"
+        return f"{K}x{N}x{self.operand_element_type}"

     def get_out(self) -> str:
-        return f"{self.M}x{self.N}x{self.result_element_type}"
+        M, N, K = self.get_runtime_dims()
+        return f"{M}x{N}x{self.result_element_type}"

     def get_byte_count(self) -> int:
         operand_bytes_per_element = num_bytes(self.operand_element_type)
         result_bytes_per_element = num_bytes(self.result_element_type)
-        byte_count_input = (self.M + self.N) * self.K * operand_bytes_per_element
-        byte_count_output = (self.M * self.N) * result_bytes_per_element
+        M, N, K = self.get_runtime_dims()
+        byte_count_input = (M + N) * K * operand_bytes_per_element
+        byte_count_output = (M * N) * result_bytes_per_element
         return byte_count_input + byte_count_output

     def get_flops(self) -> int:
-        flops = 2 * self.M * self.N * self.K
+        M, N, K = self.get_runtime_dims()
+        flops = 2 * M * N * K
         return flops

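To make the new fields concrete, here is a hypothetical usage sketch. The values are invented for illustration, and it assumes the dataclass also declares str fields tA and tB in the lines collapsed out of this hunk, as the methods above imply:

# Hypothetical config: static M and N, dynamic K executed at 1024.
config = GemmConfig(
    M=2048, N=2048, K=kDynamic,
    tA="N", tB="N",
    operand_element_type="f16",
    accumulator_element_type="f32",
    result_element_type="f32",
    runtime_dim=1024,
)
assert config.get_name() == "gemm_2048_2048_D_f16_f32_D=1024"
assert config.get_runtime_dims() == (2048, 2048, 1024)
assert config.get_flops() == 2 * 2048 * 2048 * 1024  # uses the concrete K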
@@ -123,16 +151,22 @@ def generate_mlir(config: GemmConfig):
     # Transpose A
     if tA == "T":
         arg0_type = ir.RankedTensorType.get([K, M], operand_element_type)
+        arg0_M_idx = 1
         arg1_type = ir.RankedTensorType.get([K, N], operand_element_type)
+        arg1_N_idx = 1
     # Transpose B
     elif tB == "T":
         arg0_type = ir.RankedTensorType.get([M, K], operand_element_type)
+        arg0_M_idx = 0
         arg1_type = ir.RankedTensorType.get([N, K], operand_element_type)
+        arg1_N_idx = 0
     # "Normal" path (can't transpose both)
     else:
         assert tA == "N" and tB == "N"
         arg0_type = ir.RankedTensorType.get([M, K], operand_element_type)
+        arg0_M_idx = 0
         arg1_type = ir.RankedTensorType.get([K, N], operand_element_type)
+        arg1_N_idx = 1
     result_type = ir.RankedTensorType.get([M, N], result_element_type)

     module = ir.Module.create()
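A quick reference for the index bookkeeping added above, derived directly from the operand shapes in this hunk:

# Where M and N live in the operands, per transpose mode (0-based dims):
#   tA == "T":  A is KxM -> M is dim 1 of arg0;  B is KxN -> N is dim 1 of arg1
#   tB == "T":  A is MxK -> M is dim 0 of arg0;  B is NxK -> N is dim 0 of arg1
#   normal:     A is MxK -> M is dim 0 of arg0;  B is KxN -> N is dim 1 of arg1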
@@ -143,7 +177,24 @@ def main(arg0, arg1):
         zero_element = arith.constant(
             value=literal_zero, result=acc_element_type
         )
-        empty_tensor = tensor.empty(element_type=acc_element_type, sizes=[M, N])
+        if M == kDynamic:
+            M_dynamic_dim_idx = arith.constant(
+                value=arg0_M_idx, result=ir.IndexType.get()
+            )
+            M_dynamic_dim = tensor.dim(arg0, M_dynamic_dim_idx)
+        if N == kDynamic:
+            N_dynamic_dim_idx = arith.constant(
+                value=arg1_N_idx, result=ir.IndexType.get()
+            )
+            N_dynamic_dim = tensor.dim(arg1, N_dynamic_dim_idx)
+
+        empty_tensor = tensor.empty(
+            element_type=acc_element_type,
+            sizes=[
+                M_dynamic_dim if M == kDynamic else M,
+                N_dynamic_dim if N == kDynamic else N,
+            ],
+        )
         filled_tensor = linalg.fill(zero_element, outs=[empty_tensor])

         if tA == "T":
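For readers new to these bindings, the dynamic-shape plumbing in the last hunk reduces to one pattern: query the runtime extent of each dynamic dim with tensor.dim, then pass those SSA values to tensor.empty, which takes one value per dynamic size in the result type. A self-contained sketch under that assumption; the shapes and function name are hypothetical, only M is dynamic here, and the builder signatures follow the ones used in the patch:

from iree.compiler import ir
from iree.compiler.dialects import arith, func, tensor

kDynamic = ir.ShapedType.get_dynamic_size()

with ir.Context(), ir.Location.unknown():
    module = ir.Module.create()
    f16 = ir.F16Type.get()
    f32 = ir.F32Type.get()
    # arg0 is tensor<?x4096xf16>; its M extent is only known at runtime.
    arg0_type = ir.RankedTensorType.get([kDynamic, 4096], f16)

    with ir.InsertionPoint(module.body):

        @func.FuncOp.from_py_func(arg0_type)
        def main(arg0):
            # Ask for the extent of dim 0 of arg0 at runtime ...
            zero = arith.constant(value=0, result=ir.IndexType.get())
            m = tensor.dim(arg0, zero)
            # ... and feed it to tensor.empty to size the result tensor.
            return tensor.empty(element_type=f32, sizes=[m, 4096])

    print(module)  # a func where tensor.dim feeds tensor.empty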