Skip to content

Commit 3eb5e66

Browse files
committed
add liveness analysis demo
1 parent 62a1ecc commit 3eb5e66

File tree

7 files changed

+677
-17
lines changed

7 files changed

+677
-17
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ jobs:
7575
7676
python examples/mwe.py
7777
python examples/flash_attention.py
78+
python examples/liveness_analysis.py
7879
7980
test-other-host-bindings:
8081

examples/liveness_analysis.py

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
from mlir import ir
2+
from pathlib import Path
3+
4+
import mlir.extras.types as T
5+
import numpy as np
6+
from mlir.ir import InsertionPoint, IntegerAttr, UnitAttr
7+
8+
from mlir.extras.ast.canonicalize import canonicalize
9+
from mlir.extras.context import RAIIMLIRContextModule
10+
from mlir.extras.dialects.ext import memref, scf, arith, gpu, llvm
11+
from mlir.dialects import math
12+
13+
# noinspection PyUnresolvedReferences
14+
from mlir.extras.dialects.ext.gpu import (
15+
block_idx,
16+
thread_idx,
17+
grid_dim,
18+
func as gpu_func,
19+
set_container_module,
20+
module,
21+
get_compile_object_bytes,
22+
)
23+
from mlir.extras.runtime.passes import run_pipeline, Pipeline
24+
from mlir.extras.util import find_ops, walk_blocks_in_operation, walk_operations
25+
from mlir.extras.util.liveness import (
26+
BlockInfoBuilder,
27+
Liveness,
28+
LiveInterval,
29+
linear_scan_register_allocation,
30+
)
31+
32+
# just so it doesn't get DCE'd by black/reformat
33+
# TypeError: 'mlir._mlir_libs._mlir.ir.BlockArgument' object is not subscriptable
34+
_ = memref
35+
36+
ctx = RAIIMLIRContextModule()
37+
set_container_module(ctx.module)
38+
39+
40+
# just a default attr - actual target is set blow
41+
@module("kernels", [f'#rocdl.target<abi = "500">'])
42+
def gpu_module():
43+
pass
44+
45+
46+
ip = InsertionPoint.at_block_begin(gpu_module.regions[0].blocks[0])
47+
ip.__enter__()
48+
49+
Bc = 32
50+
Br = 32
51+
52+
B = 16
53+
nh = 12
54+
N = 128
55+
d = 128
56+
57+
softmax_scale = 1.0 / float(np.sqrt(d))
58+
59+
60+
rank_reduce = memref.rank_reduce
61+
62+
63+
# https://github.com/tspeterkim/flash-attention-minimal/blob/main/flash.cu
64+
@gpu_func(emit=True)
65+
@canonicalize(using=[scf.canonicalizer, arith.canonicalizer])
66+
def flash_attention(
67+
Q: T.memref(B, nh, N, d, T.f32()),
68+
K: T.memref(B, nh, N, d, T.f32()),
69+
V: T.memref(B, nh, N, d, T.f32()),
70+
l: T.memref(B, nh, N, T.f32()),
71+
m: T.memref(B, nh, N, T.f32()),
72+
O: T.memref(B, nh, N, d, T.f32()),
73+
):
74+
tx = thread_idx.x
75+
# batch idx, head_idx
76+
bx, by = block_idx.x, block_idx.y
77+
# gpu.printf("bx %ld, by %ld\n", bx, by)
78+
79+
# Offset into Q,K,V,O,l,m - different for each batch and head
80+
K = K[bx, by, :, :, rank_reduce]
81+
V = V[bx, by, :, :, rank_reduce]
82+
Q = Q[bx, by, :, :, rank_reduce]
83+
O = O[bx, by, :, :, rank_reduce]
84+
l = l[bx, by, :, rank_reduce]
85+
m = m[bx, by, :, rank_reduce]
86+
87+
# Define SRAM for Q,K,V,S
88+
sram = gpu.dynamic_shared_memory()
89+
Qi = memref.view(sram, (Br, d), dtype=T.f32())
90+
Kj = memref.view(sram, (Bc, d), dtype=T.f32(), shift=Qi.n_elements)
91+
Vj = memref.view(sram, (Bc, d), dtype=T.f32(), shift=Qi.n_elements + Kj.n_elements)
92+
S = memref.view(
93+
sram,
94+
(Br, Bc),
95+
dtype=T.f32(),
96+
shift=Qi.n_elements + Kj.n_elements + Vj.n_elements,
97+
)
98+
99+
for bc in scf.range_(0, N, Bc):
100+
# Load Kj, Vj to SRAM
101+
K_ = K[bc : bc + 1, :]
102+
V_ = V[bc : bc + 1, :]
103+
for x in scf.range_(0, d):
104+
Kj[tx, x] = K_[tx, x]
105+
Vj[tx, x] = V_[tx, x]
106+
107+
for br in scf.range_(0, N, Br):
108+
# Load Qi to SRAM, l and m to registers
109+
Q_ = Q[br : br + 1, :]
110+
for x in scf.range_(0, d):
111+
Qi[tx, x] = Q_[tx, x]
112+
113+
l_ = l[br : br + 1]
114+
m_ = m[br : br + 1]
115+
row_l_prev = l_[tx]
116+
row_m_prev = m_[tx]
117+
118+
# S = QK^T, row_m = rowmax(S)
119+
row_m: T.f32() = float(np.finfo(np.float32).min)
120+
for y, row_m, _ in scf.range_(0, Bc, iter_args=[row_m]):
121+
sum: T.f32() = 0.0
122+
for x, sum, _ in scf.range_(0, d, iter_args=[sum]):
123+
sum += Qi[tx, x] * Kj[y, x]
124+
sum = yield sum
125+
126+
sum *= softmax_scale
127+
S[tx, y] = sum
128+
129+
if sum > row_m:
130+
row_m_ = yield sum
131+
else:
132+
row_m_ = yield row_m
133+
134+
row_m = yield row_m_
135+
136+
# P = exp(S - row_m), row_l = rowsum(P)
137+
row_l: T.f32() = 0.0
138+
for y, row_l, _ in scf.range_(0, Bc, iter_args=[row_l]):
139+
S[tx, y] = math.exp(S[tx, y] - row_m)
140+
row_l += S[tx, y]
141+
row_l = yield row_l
142+
143+
# Compute new m and l
144+
row_m_new = arith.maximumf(row_m_prev, row_m)
145+
row_l_new = (
146+
math.exp(row_m_prev - row_m_new) * row_l_prev
147+
+ math.exp(row_m - row_m_new) * row_l
148+
)
149+
div = 1.0 / row_l_new
150+
f1 = row_l_prev * math.exp(row_m_prev - row_m_new)
151+
f2 = math.exp(row_m - row_m_new)
152+
153+
# Write O, l, m to HBM
154+
O_ = O[br : br + 1, :]
155+
for x in scf.range_(0, d):
156+
pv: T.f32() = 0.0 # Pij * Vj
157+
for y, pv, _ in scf.range_(0, Bc, iter_args=[pv]):
158+
pv += S[tx, y] * Vj[y, x]
159+
pv = yield pv
160+
161+
O_[tx, x] = div * (f1 * O_[tx, x] + f2 * pv)
162+
163+
l_[tx] = row_l_new
164+
m_[tx] = row_m_new
165+
166+
gpu.barrier()
167+
168+
169+
ip.__exit__(None, None, None)
170+
171+
assert gpu_module.operation.verify()
172+
# l = Liveness(gpu_module)
173+
# print(l)
174+
175+
176+
# https://langdev.stackexchange.com/questions/4325/how-do-modern-compilers-choose-which-variables-to-put-in-registers
177+
x = LiveInterval(1, 3, "x")
178+
t1 = LiveInterval(1, 2, "t1")
179+
y = LiveInterval(2, 5, "y")
180+
z = LiveInterval(3, 4, "z")
181+
t2 = LiveInterval(4, 5, "t2")
182+
y2 = LiveInterval(5, 6, "y2")
183+
184+
register, location = linear_scan_register_allocation([x, t1, y, z, t2, y2], 2)
185+
186+
for v, r in register.items():
187+
print(v, r)
188+
for v, l in location.items():
189+
print(v, l)

mlir/extras/util/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from .util import *
2+
from .util import (
3+
_get_previous_frame_idents,
4+
_get_sym_name,
5+
_update_caller_vars,
6+
_unpack_sizes_element_type,
7+
)

0 commit comments

Comments
 (0)