
Commit 6f67a97

[TK] Add support for ops required for Flash Attention 2 (#385)
Add new ops:
- tkl.exp2 (elementwise 2^x; lowers to the MLIR math dialect)
- tkl.max (reduce max)
- tkl.sum (reduce sum)
- tkl.broadcast (broadcast over new leading dims)
- tkl.broadcast_in_dim (broadcast over specific dimensions)
- tkl.transpose (transpose)
1 parent da57fe3 commit 6f67a97
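
These primitives cover the tile-level math of Flash Attention 2's online softmax: reduce each row of a q·kᵀ tile to its max, re-expand that max over the tile, exponentiate, and reduce row sums (FA2 kernels typically fold log2(e) into the logit scale so exp2 can stand in for exp). Below is a standalone numpy sketch of the step these ops enable, with each line annotated with the corresponding new op (illustrative only, not code from this commit):

import numpy as np

# One q·kᵀ tile of attention logits (scaling assumed already applied).
qk = np.random.rand(4, 8).astype(np.float32)

m = qk.max(axis=1)            # tkl.max: row-wise reduce max, shape (4,)
p = np.exp2(qk - m[:, None])  # tkl.broadcast_in_dim then tkl.exp2
s = p.sum(axis=1)             # tkl.sum: row-wise reduce sum, shape (4,)
attn = p / s[:, None]         # normalized tile: each row sums to 1
assert np.allclose(attn.sum(axis=1), 1.0)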

11 files changed (+311, -88 lines)

python/shark_turbine/kernel/_support/tracing.py

Lines changed: 75 additions & 1 deletion
@@ -264,6 +264,13 @@ def wrapper(f):
     ### ========================================================================
     ### Math Operations
     ### ========================================================================
+    def handle_exp2(self, op, val):
+        return self.region_graph.create_proxy(
+            "call_function",
+            target=op,
+            args=(val,),
+            kwargs={},
+        )

     def handle_vector_constant(
         self, op, shape: Tuple[int, ...], dtype, value: int | float
@@ -278,15 +285,82 @@ def handle_vector_constant(
     ### ========================================================================
     ### Reduction Operations
     ### ========================================================================
+    def handle_vector_max(self, op, vector, axis=None, acc=None):
+        return self.region_graph.create_proxy(
+            "call_function",
+            target=op,
+            args=(vector, axis, acc),
+            kwargs={},
+        )
+
+    def handle_vector_sum(self, op, vector, axis=None, acc=None):
+        return self.region_graph.create_proxy(
+            "call_function",
+            target=op,
+            args=(vector, axis, acc),
+            kwargs={},
+        )

-    def handle_vector_dot(self, op, lhs, rhs, acc):
+    def handle_vector_dot(self, op, lhs, rhs, acc=None):
         return self.region_graph.create_proxy(
             "call_function",
             target=op,
             args=(lhs, rhs, acc),
             kwargs={},
         )

+    ### ========================================================================
+    ### Shape Manipulation Operations
+    ### ========================================================================
+    def handle_vector_broadcast(self, op, vector, leading_sizes):
+        return self.region_graph.create_proxy(
+            "call_function",
+            target=op,
+            args=(vector, leading_sizes),
+            kwargs={},
+        )
+
+    def handle_vector_broadcast_in_dim(self, op, vector, shape, broadcast_dimensions):
+        # Currently, we do not have a corresponding op in MLIR, so
+        # we trace this to broadcast + transpose.
+        # TODO: Add a vector dialect op for this in MLIR.
+
+        # Remove broadcast_dimensions from shape.
+        shape_with_leading = tuple(
+            dim for i, dim in enumerate(shape) if i not in broadcast_dimensions
+        )
+
+        # Broadcast across the new leading dims.
+        broadcasted_vector = self.region_graph.create_proxy(
+            "call_function",
+            target=ops.vector_broadcast,
+            args=(vector, shape_with_leading),
+            kwargs={},
+        )
+
+        # Get the permutation for the transpose.
+        permutation = tuple(
+            i for i in range(len(shape)) if i not in broadcast_dimensions
+        )
+        permutation = permutation + tuple(broadcast_dimensions)
+
+        # Transpose the broadcast dims back to their requested positions.
+        return self.region_graph.create_proxy(
+            "call_function",
+            target=ops.vector_transpose,
+            args=(broadcasted_vector, permutation),
+            kwargs={},
+        )
+
+    def handle_vector_transpose(self, op, vector, permutation):
+        return self.region_graph.create_proxy(
+            "call_function",
+            target=op,
+            args=(vector, permutation),
+            kwargs={},
+        )
+

 ###############################################################################
 # Launch context
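
The broadcast + transpose decomposition traced above is easy to sanity-check outside the tracer. Here is a standalone numpy sketch (illustrative only, not code from this commit) that mirrors handle_vector_broadcast_in_dim step for step:

import numpy as np

def broadcast_in_dim_ref(vector, shape, broadcast_dimensions):
    # Target dims the input does not already occupy become leading sizes.
    shape_with_leading = tuple(
        dim for i, dim in enumerate(shape) if i not in broadcast_dimensions
    )
    # Step 1 (tkl.broadcast): replicate across the new leading dims.
    broadcasted = np.broadcast_to(vector, shape_with_leading + vector.shape)
    # Step 2 (tkl.transpose): move the preserved input dims back to the
    # positions requested by broadcast_dimensions.
    permutation = tuple(
        i for i in range(len(shape)) if i not in broadcast_dimensions
    ) + tuple(broadcast_dimensions)
    return np.transpose(broadcasted, permutation)

row_max = np.array([1.0, 2.0])                     # shape (2,)
out = broadcast_in_dim_ref(row_max, (2, 3), (0,))  # input dim 0 -> output dim 0
assert out.shape == (2, 3)
assert np.array_equal(out, np.array([[1.0] * 3, [2.0] * 3]))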

python/shark_turbine/kernel/compiler/builder.py

Lines changed: 29 additions & 2 deletions
@@ -24,6 +24,7 @@
     Value,
     VectorType,
     arith_d,
+    math_d,
     builtin_d,
 )

@@ -139,7 +140,7 @@ def binary_arithmetic(

     def binary_vector_arithmetic(
         self, op: str, lhs: IRProxyValue, rhs: IRProxyValue
-    ) -> Value:
+    ) -> IRProxyValue:
         lhs_ir = lhs.ir_value
         rhs_ir = rhs.ir_value
         lhs_element_type = VectorType(lhs_ir.type).element_type
@@ -149,10 +150,33 @@ def binary_vector_arithmetic(
             handler = getattr(self, attr_name)
         except AttributeError:
             raise CodegenError(
-                f"Cannot perform binary arithmetic operation '{op}' between {lhs.type} and {rhs.type} (tried '{attr_name}')"
+                f"Cannot perform binary arithmetic operation '{op}' between {lhs_ir.type} and {rhs_ir.type} (tried '{attr_name}')"
             )
         return handler(lhs, rhs)

+    def unary_arithmetic(self, op: str, val: IRProxyValue) -> IRProxyValue:
+        val_ir_type = val.ir_value.type
+        attr_name = f"unary_{op}_{val_ir_type}"
+        try:
+            handler = getattr(self, attr_name)
+        except AttributeError:
+            raise CodegenError(
+                f"Cannot perform unary arithmetic operation '{op}' on {val_ir_type} (tried '{attr_name}')"
+            )
+        return handler(val)
+
+    def unary_vector_arithmetic(self, op: str, val: IRProxyValue) -> IRProxyValue:
+        val_ir = val.ir_value
+        val_element_type = VectorType(val_ir.type).element_type
+        attr_name = f"unary_{op}_{val_element_type}"
+        try:
+            handler = getattr(self, attr_name)
+        except AttributeError:
+            raise CodegenError(
+                f"Cannot perform unary arithmetic operation '{op}' on {val_ir.type} (tried '{attr_name}')"
+            )
+        return handler(val)
+
     def promote_index_to_f32(self, value: Value, to_type: IrType) -> Value:
         i32_type = IntegerType.get_signless(32)
         i32 = arith_d.index_cast(i32_type, value)
@@ -215,5 +239,8 @@ def binary_truediv_f32_f32(
     ) -> IRProxyValue:
         return IRProxyValue(arith_d.divf(lhs.ir_value, rhs.ir_value))

+    def unary_exp2_f32(self, val: IRProxyValue) -> IRProxyValue:
+        return IRProxyValue(math_d.exp2(val.ir_value))
+

 ScalarBuilder = _ScalarBuilder()
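
Note how unary_exp2_f32 is never called directly: unary_arithmetic and unary_vector_arithmetic build the handler name from the op and the operand's (element) type and resolve it with getattr, so an unsupported pair fails with a descriptive CodegenError. A minimal standalone sketch of that dispatch pattern (toy string types and a toy builder, not the real MLIR-backed one):

class TinyScalarBuilder:
    def unary_arithmetic(self, op: str, val: float, val_type: str) -> float:
        # Resolve the handler by naming convention, e.g. "unary_exp2_f32".
        attr_name = f"unary_{op}_{val_type}"
        try:
            handler = getattr(self, attr_name)
        except AttributeError:
            raise NotImplementedError(
                f"Cannot perform unary arithmetic operation '{op}' "
                f"on {val_type} (tried '{attr_name}')"
            )
        return handler(val)

    # Supporting a new (op, type) pair is just one more method.
    def unary_exp2_f32(self, val: float) -> float:
        return 2.0 ** val

print(TinyScalarBuilder().unary_arithmetic("exp2", 3.0, "f32"))  # 8.0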
