+ from pathlib import Path
+
import mlir.extras.types as T
import numpy as np
from hip import hip
from mlir.ir import InsertionPoint, IntegerAttr, UnitAttr
-
from mlir.extras.ast.canonicalize import canonicalize
from mlir.extras.context import RAIIMLIRContextModule
from mlir.extras.dialects.ext import memref, scf, arith, gpu, llvm
# noinspection PyUnresolvedReferences
from util import hip_check, launch_kernel, hip_synchronize

+
+ def init_copy_host_device():
+     q_h = np.random.randint(0, 10, (B * nh * N * d)).astype(dtype=np.float32)
+     k_h = np.random.randint(0, 10, (B * nh * N * d)).astype(dtype=np.float32)
+     v_h = np.random.randint(0, 10, (B * nh * N * d)).astype(dtype=np.float32)
+     l_h = np.zeros((B * nh * N), dtype=np.float32)
+     m_h = np.full((B * nh * N), float(np.finfo(np.float32).min), dtype=np.float32)
+     O_h = np.zeros_like(q_h, dtype=np.float32)
+
+     host = [q_h, k_h, v_h, l_h, m_h, O_h]
+     device = [hip_check(hip.hipMalloc(h.size * h.itemsize)) for h in host]
+
+     for dev, h in zip(device, host):
+         hip_check(
+             hip.hipMemcpy(
+                 dev, h, h.size * h.itemsize, hip.hipMemcpyKind.hipMemcpyHostToDevice
+             )
+         )
+
+     return host, device
+
+
+ def copy_device_host(host, device):
+     for d, h in zip(device, host):
+         hip_check(
+             hip.hipMemcpy(
+                 h, d, h.size * h.itemsize, hip.hipMemcpyKind.hipMemcpyDeviceToHost
+             )
+         )
+         hip_check(hip.hipFree(d))
+
+     return host
+
+
# just so it doesn't get DCE'd by black/reformat
# TypeError: 'mlir._mlir_libs._mlir.ir.BlockArgument' object is not subscriptable
_ = memref
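For reference, a minimal sketch of how the two new helpers pair up; these call sites are hypothetical (the real ones appear in the benchmark loop further down) and assume B, nh, N, d are already defined at module scope, as they are later in this file:

host, device = init_copy_host_device()  # host arrays + hipMalloc'd device buffers, H2D copies done
# ... launch a kernel that reads q/k/v and writes l/m/O through `device` ...
q_h, k_h, v_h, l_h, m_h, O_h = copy_device_host(host, device)  # D2H copies, then hipFree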
@@ -44,25 +79,19 @@ def gpu_module():
ip = InsertionPoint.at_block_begin(gpu_module.regions[0].blocks[0])
ip.__enter__()

- batch_size = 16
- n_head = 12
- seq_len = 64
- head_embd = 64
-
Bc = 32
Br = 32

- B = batch_size
- nh = n_head
- N = seq_len
- d = head_embd
+ B = 16
+ nh = 12
+ N = 128
+ d = 128

import math

Tc = math.ceil(N / Bc)
Tr = math.ceil(N / Br)
softmax_scale = 1.0 / math.sqrt(d)
- tile_size = Bc * d  # size of Qi, Kj, Vj


def softmax(x, axis=None):
@@ -75,11 +104,11 @@ def manual_attn(q, k, v):
    # the kernel below overwrites the global math.........
    import math

-     q = q.reshape(batch_size, n_head, seq_len, head_embd)
-     k = k.reshape(batch_size, n_head, seq_len, head_embd)
-     v = v.reshape(batch_size, n_head, seq_len, head_embd)
+     q = q.reshape(B, nh, N, d)
+     k = k.reshape(B, nh, N, d)
+     v = v.reshape(B, nh, N, d)

-     att = q @ k.transpose(0, 1, -2, -1) * (1.0 / math.sqrt(k.shape[-1]))
+     att = q @ k.transpose(0, 1, 3, 2) * (1.0 / math.sqrt(k.shape[-1]))
    att = softmax(att, axis=-1)
    y = att @ v
    return y.flatten()
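For orientation, a self-contained NumPy re-derivation of what manual_attn computes, at tiny stand-in sizes (B_, nh_, N_, d_ are illustrative placeholders, not the constants above):

import math
import numpy as np

B_, nh_, N_, d_ = 2, 3, 8, 4
q = np.random.rand(B_, nh_, N_, d_).astype(np.float32)
k = np.random.rand(B_, nh_, N_, d_).astype(np.float32)
v = np.random.rand(B_, nh_, N_, d_).astype(np.float32)

att = q @ k.transpose(0, 1, 3, 2) * (1.0 / math.sqrt(d_))  # scores: (B_, nh_, N_, N_)
att = np.exp(att - att.max(axis=-1, keepdims=True))        # numerically stable softmax
att /= att.sum(axis=-1, keepdims=True)
y = att @ v                                                 # output: (B_, nh_, N_, d_)
assert y.shape == (B_, nh_, N_, d_)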
@@ -92,40 +121,46 @@ def manual_attn(q, k, v):
@gpu_func(emit=True)
@canonicalize(using=[scf.canonicalizer, arith.canonicalizer])
def flash_attention(
-     Q: T.memref(batch_size * n_head * seq_len * head_embd, T.f32()),
-     K: T.memref(batch_size * n_head * seq_len * head_embd, T.f32()),
-     V: T.memref(batch_size * n_head * seq_len * head_embd, T.f32()),
+     Q: T.memref(B * nh * N * d, T.f32()),
+     K: T.memref(B * nh * N * d, T.f32()),
+     V: T.memref(B * nh * N * d, T.f32()),
    l: T.memref(B * nh * N, T.f32()),
    m: T.memref(B * nh * N, T.f32()),
-     O: T.memref(batch_size * n_head * seq_len * head_embd, T.f32()),
+     O: T.memref(B * nh * N * d, T.f32()),
):
    tx = thread_idx.x
-     bx = block_idx.x
-     by = block_idx.y  # batch and head index
+     # batch idx, head_idx
+     bx, by = block_idx.x, block_idx.y
+     # gpu.printf("bx %ld, by %ld\n", bx, by)

    # Offset into Q,K,V,O,l,m - different for each batch and head
-     qkv_offset = bx * grid_dim.y * N * d + by * N * d  # gridDim.y = nh
-     lm_offset = bx * grid_dim.y * N + by * N  # offset for l and m
+     qkv_offset = bx * nh * N * d + by * N * d
+     lm_offset = bx * nh * N + by * N  # offset for l and m

    # Define SRAM for Q,K,V,S
    sram = gpu.dynamic_shared_memory()
-     Qi = memref.view(sram, (tile_size,), dtype=T.f32())
-     Kj = memref.view(sram, (tile_size,), dtype=T.f32(), shift=tile_size * 1)
-     Vj = memref.view(sram, (tile_size,), dtype=T.f32(), shift=tile_size * 2)
-     S = memref.view(sram, (tile_size,), dtype=T.f32(), shift=tile_size * 3)
+     Qi = memref.view(sram, (Br * d,), dtype=T.f32())
+     Kj = memref.view(sram, (Bc * d,), dtype=T.f32(), shift=Qi.n_elements)
+     Vj = memref.view(
+         sram, (Bc * d,), dtype=T.f32(), shift=Qi.n_elements + Kj.n_elements
+     )
+     S = memref.view(
+         sram,
+         (Br * Bc,),
+         dtype=T.f32(),
+         shift=Qi.n_elements + Kj.n_elements + Vj.n_elements,
+     )

    for j in scf.range_(0, Tc):
        # Load Kj, Vj to SRAM
        for x in scf.range_(0, d):
-             Kj[tx * d + x] = K[qkv_offset + tile_size * j + tx * d + x]
-             Vj[tx * d + x] = V[qkv_offset + tile_size * j + tx * d + x]
-
-         gpu.barrier()  # such that the inner loop can use the correct Kj, Vj
+             Kj[tx * d + x] = K[qkv_offset + Bc * d * j + tx * d + x]
+             Vj[tx * d + x] = V[qkv_offset + Bc * d * j + tx * d + x]

        for i in scf.range_(0, Tr):
            # Load Qi to SRAM, l and m to registers
            for x in scf.range_(0, d):
-                 ii = qkv_offset + tile_size * i + tx * d + x
+                 ii = qkv_offset + Bc * d * i + tx * d + x
                Qi[tx * d + x] = Q[ii]

            row_m_prev = m[lm_offset + Br * i + tx]
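A quick arithmetic sketch of the shared-memory layout set up above, using the Br = Bc = 32, d = 128 constants from this commit (sram_size is the dynamic LDS allocation requested further down):

import math

Br, Bc, d = 32, 32, 128
Tc = Tr = math.ceil(128 / 32)               # 4 key/value tiles and 4 query tiles
elems = Br * d + Bc * d + Bc * d + Br * Bc  # Qi + Kj + Vj + S = 13312 f32 elements
used_bytes = elems * 4                      # 53248 bytes actually viewed out of `sram`
alloc_bytes = 4 * Bc * d * 4                # 65536 bytes, matching `sram_size` below
assert used_bytes <= alloc_bytes            # the allocation comfortably covers the views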
@@ -172,21 +207,18 @@ def flash_attention(
                    pv += S[Bc * tx + y] * Vj[y * d + x]
                    pv = yield pv

-                 ii = qkv_offset + tile_size * i + tx * d + x
+                 ii = qkv_offset + Bc * d * i + tx * d + x
                O[ii] = div * (c * O[ii] + math.exp(row_m - row_m_new) * pv)

-             gpu.barrier()  # otherwise, thread can use the wrong Kj, Vj in inner loop
-
            m[lm_offset + Br * i + tx] = row_m_new
            l[lm_offset + Br * i + tx] = row_l_new

-         # gpu.barrier() # otherwise, thread can use the wrong Kj, Vj in inner loop
-         # gpu.barrier() # otherwise, thread can use the wrong Kj, Vj in inner loop
+         gpu.barrier()


ip.__exit__(None, None, None)

- sram_size = 4 * tile_size * np.float32().itemsize
+ sram_size = 4 * Bc * d * np.float32().itemsize

launch_params = {
    flash_attention.__name__: (
@@ -206,6 +238,9 @@ def flash_attention(
    .rocdl_attach_target(chip=arch, O=3, abi="500"),
)

+ # print(simplified_module)
+ # exit()
+
lowered_module = run_pipeline(
    simplified_module,
    Pipeline()
@@ -216,7 +251,8 @@ def flash_attention(
        )
    )
    .gpu_to_llvm()
-     .lower_to_llvm(),
+     .lower_to_llvm()
+     .ensure_debug_info_scope_on_llvm_func(emission_kind="Full"),
    # .Nested("llvm.func", Pipeline().sroa()),
)

@@ -236,68 +272,34 @@ def flash_attention(
    T.index(), np.prod(thread_dims)
)

- lowered_module = run_pipeline(lowered_module, Pipeline().gpu_module_to_binary())
+ output_format = "bin"
+ # output_format = "llvm"
+ # output_format = "isa"
+
+ lowered_module = run_pipeline(
+     lowered_module, Pipeline().gpu_module_to_binary(format=output_format)
+ )
hsaco = get_compile_object_bytes(lowered_module)
+ if output_format in {"isa", "llvm", "offloading"}:
+     with open(Path(__file__).parent / "flashattention.amdgcn", "wb") as f:
+         f.write(hsaco)
+     exit()

hip_module = hip_check(hip.hipModuleLoadData(hsaco))

- q_h = np.random.randint(0, 10, (batch_size * n_head * seq_len * head_embd)).astype(
-     dtype=np.float32
- )
- k_h = np.random.randint(0, 10, (batch_size * n_head * seq_len * head_embd)).astype(
-     dtype=np.float32
- )
- v_h = np.random.randint(0, 10, (batch_size * n_head * seq_len * head_embd)).astype(
-     dtype=np.float32
- )
- l_h = np.zeros((B * nh * N), dtype=np.float32)
- m_h = np.full((B * nh * N), float(np.finfo(np.float32).min), dtype=np.float32)
- O_h = np.zeros_like(q_h, dtype=np.float32)
-
- q_num_bytes = q_h.size * q_h.itemsize
- k_num_bytes = k_h.size * k_h.itemsize
- v_num_bytes = v_h.size * v_h.itemsize
- l_num_bytes = l_h.size * l_h.itemsize
- m_num_bytes = m_h.size * m_h.itemsize
- O_num_bytes = O_h.size * O_h.itemsize
-
- q_d = hip_check(hip.hipMalloc(q_num_bytes))
- k_d = hip_check(hip.hipMalloc(k_num_bytes))
- v_d = hip_check(hip.hipMalloc(v_num_bytes))
- l_d = hip_check(hip.hipMalloc(l_num_bytes))
- m_d = hip_check(hip.hipMalloc(m_num_bytes))
- O_d = hip_check(hip.hipMalloc(O_num_bytes))
-
stream = 0

times = {
    flash_attention: 0,
}
- # random.shuffle(kernels)
- runs = 16
+ runs = 32
for kernel in times:
    for i in range(runs):
        function = hip_check(
            hip.hipModuleGetFunction(hip_module, kernel.__name__.encode())
        )
        hip_check(hip.hipDeviceSynchronize())

-         for d, h, num_bytes in zip(
-             [q_d, k_d, v_d, l_d, m_d, O_d],
-             [q_h, k_h, v_h, l_h, m_h, O_h],
-             [
-                 q_num_bytes,
-                 k_num_bytes,
-                 v_num_bytes,
-                 l_num_bytes,
-                 m_num_bytes,
-                 O_num_bytes,
-             ],
-         ):
-             hip_check(
-                 hip.hipMemcpy(d, h, num_bytes, hip.hipMemcpyKind.hipMemcpyHostToDevice)
-             )
-
        (
            (
                blocks_per_grid_x,
@@ -312,6 +314,10 @@ def flash_attention(
            shared_memory,
        ) = launch_params[kernel.__name__]

+         host, device = init_copy_host_device()
+         q_h, k_h, v_h, *_ = host
+         correct = manual_attn(q_h, k_h, v_h)
+
        time_compute = launch_kernel(
            function.as_c_void_p(),
            blocks_per_grid_x,
@@ -322,36 +328,20 @@ def flash_attention(
            threads_per_block_z,
            stream,
            shared_memory,
-             q_d,
-             k_d,
-             v_d,
-             l_d,
-             m_d,
-             O_d,
+             *device,
        )

-         hip_check(
-             hip.hipMemcpy(
-                 l_h, l_d, l_num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost
-             )
-         )
-         hip_check(
-             hip.hipMemcpy(
-                 m_h, m_d, m_num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost
-             )
-         )
-         hip_check(
-             hip.hipMemcpy(
-                 O_h, O_d, O_num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost
-             )
-         )
-         correct = manual_attn(q_h, k_h, v_h)
+         *_, O_h = copy_device_host(host, device)
        if not np.allclose(correct, O_h):
-             print("correct", correct)
-             print("l_h", l_h)
-             print("m_h", m_h)
-             print("output", O_h)
-             print(f"{kernel.__name__} failed")
+             with np.printoptions(threshold=np.inf, linewidth=np.inf):
+                 print(
+                     "correct - output:\n",
+                     correct.round().reshape(B, nh, N, d)
+                     - O_h.round().reshape(B, nh, N, d),
+                 )
+             print(f"{kernel.__name__} failed\n")
+         else:
+             print(f"{kernel.__name__}: {time_compute:.03f} ms")

        times[kernel] += time_compute
