#!/usr/bin/env python
+from pathlib import Path

import mlir.extras.types as T
import numpy as np

from mlir.extras.ast.canonicalize import canonicalize
from mlir.extras.context import RAIIMLIRContextModule
-from mlir.extras.dialects.ext import memref, scf, arith, rocdl
+from mlir.extras.dialects.ext import memref, scf, arith, rocdl, gpu, llvm, vector

# noinspection PyUnresolvedReferences
from mlir.extras.dialects.ext.gpu import (
    module,
    get_compile_object_bytes,
    lds_space,
+    dynamic_shared_memory,
)
from mlir.extras.runtime.passes import run_pipeline, Pipeline

@@ -43,10 +45,6 @@ def time_to_gflops(time_ms, N):
ctx = RAIIMLIRContextModule()
set_container_module(ctx.module)

-props = hip.hipDeviceProp_t()
-hip_check(hip.hipGetDeviceProperties(props, 0))
-arch = props.gcnArchName.decode()
-

# just a default attr - the actual target is set below
@module("kernels", [f'#rocdl.target<abi = "500">'])
@@ -60,40 +58,44 @@ def gpu_module():
set_container_module(ctx.module)

v_len = 16
-M, K, N = 1024, 1024, 1024
-v16f16 = T.vector(v_len, T.f16())
+M, K, N = 512, 512, 512
+TILE_SIZE = BK = 16
+dtype = T.f16()
+np_dtype = np.float16
+v16 = T.vector(v_len, dtype)


@gpu_func
@canonicalize(using=scf.canonicalizer)
-def smol_matmul(
-    a: T.memref(M, K, T.f16()),
-    b: T.memref(K, N, T.f16()),
-    c: T.memref(M, N, T.f16()),
+def kernel(
+    A: T.memref(M, K, dtype), B: T.memref(K, N, dtype), C: T.memref(M, N, dtype)
):
-    lIdx = thread_idx.x
-    # a and b fragments are stored in 8 VGPRs each, in packed format, so 16 elements each for a and b
-    # a_frag will store one column of the 16x16 matrix A tile
-    # b_frag will store one row of the 16x16 matrix B tile
-    a_frag = arith.constant(np.full([v_len], 0.0, np.float16), v16f16)
-    b_frag = arith.constant(np.full([v_len], 0.0, np.float16), v16f16)
-    c_frag = arith.constant(np.full([v_len], 0.0, np.float16), v16f16)
-
-    # lane is (0-31) mod 16 instead of 0-31 due to matrix replication in RDNA 3
-    lane = lIdx % v_len
-    for ele in range(v_len):
-        b_frag[ele] = b[ele, lane]
-        a_frag[ele] = a[lane, ele]
-    # a_frag, b_frag = yield a_frag, b_frag
-
-    # call the WMMA intrinsic
-    false = arith.constant(False, T.bool())
-    c_frag = rocdl.wmma_f16_16x16x16_f16(v16f16, [a_frag, b_frag, c_frag, false])
-
-    for ele in range(v_len // 2):
-        r = ele * 2 + (lIdx // v_len)
-        # store results from unpacked c_frag output
-        c[r, lane] = c_frag[ele * 2]
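+    # Classic tiled matmul: each 16x16 thread block computes one TILE_SIZE x
+    # TILE_SIZE tile of C, staging tiles of A and B through LDS. As and Bs below
+    # are two views into a single dynamic shared-memory allocation, with Bs
+    # offset one tile past As.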
+    base = dynamic_shared_memory()
+    As = memref.view(base, (TILE_SIZE, TILE_SIZE), dtype=dtype)
+    Bs = memref.view(
+        base, (TILE_SIZE, TILE_SIZE), dtype=dtype, shift=TILE_SIZE * TILE_SIZE
+    )
+
+    row = block_idx.y * TILE_SIZE + thread_idx.y
+    col = block_idx.x * TILE_SIZE + thread_idx.x
+
+    sum = arith.constant(np.full([v_len], 0.0, np_dtype), v16)
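+    # scf.range_ carries `sum` as a loop iter_arg; the `sum = yield sum` at the
+    # bottom of the body is what threads it into the next iteration.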
+    for t, sum, _ in scf.range_(0, N, BK, iter_args=[sum]):
+        Bs[thread_idx.y, thread_idx.x] = B[thread_idx.y + t, col]
+        As[thread_idx.y, thread_idx.x] = A[row, thread_idx.x + t]
+
+        gpu.barrier()
+
+        a_frag = As @ vector.load(v16) @ [thread_idx.y, 0]
+        b_frag = Bs @ vector.load(v16) @ [0, thread_idx.x]
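+        # One wave-wide 16x16x16 f16 matmul-accumulate. The trailing i1 operand
+        # is, per the RDNA3 WMMA intrinsics, the OPSEL bit choosing which half
+        # of the packed output VGPRs receives the f16 results.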
+        false = arith.constant(False, T.bool())
+        sum = rocdl.wmma_f16_16x16x16_f16(v16, [a_frag, b_frag, sum, false])
+
+        gpu.barrier()
+
+        sum = yield sum
+
+    C[row, col] = sum


props = hip.hipDeviceProp_t()
@@ -103,31 +105,38 @@ def smol_matmul(

@module("naive", [f'#rocdl.target<chip = "{arch}", abi = "500">'])
def gpu_module():
-    smol_matmul.emit()
+    kernel.emit()


ip.__exit__(None, None, None)

+O = 3
+output_format = "binary"
+
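+# Lower through ROCDL to LLVM and serialize the GPU module for the chip detected
+# above; O is the backend optimization level fed to rocdl_attach_target below.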
lowered_module = run_pipeline(
    gpu_module,
    Pipeline()
    .Gpu(Pipeline().convert_gpu_to_rocdl(use_bare_ptr_memref_call_conv=True))
-    .rocdl_attach_target(chip=arch, abi="500", O=0)
+    .rocdl_attach_target(chip=arch, abi="500", O=O)
    .gpu_to_llvm()
    .lower_to_llvm()
    .ensure_debug_info_scope_on_llvm_func(emission_kind="Full")
-    .gpu_module_to_binary(),
+    .gpu_module_to_binary(format=output_format),
)

hsaco = get_compile_object_bytes(lowered_module)
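+# With format="assembly" the compile object is ISA text rather than a loadable
+# HSACO, so dump it next to this script and stop before hipModuleLoadData.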
+if output_format == "assembly":
+    with open(Path(__file__).parent / f"hsacoO{O}.txt", "wb") as f:
+        f.write(hsaco)
+    exit()
hip_module = hip_check(hip.hipModuleLoadData(hsaco))
-function = hip_check(
-    hip.hipModuleGetFunction(hip_module, smol_matmul.__name__.encode())
-)
+function = hip_check(hip.hipModuleGetFunction(hip_module, kernel.__name__.encode()))

-a_h = np.random.randint(0, 10, (M, K)).astype(dtype=np.float16)
-b_h = np.random.randint(0, 10, (K, N)).astype(dtype=np.float16)
-c_h = -3 * np.ones((M, N), dtype=np.float16)
+# a_h = np.random.randint(0, 10, (M, K)).astype(dtype=np_dtype)
+# b_h = np.random.randint(0, 10, (K, N)).astype(dtype=np_dtype)
+a_h = np.ones((M, K)).astype(dtype=np_dtype)
+b_h = np.ones((K, N)).astype(dtype=np_dtype)
+c_h = -3 * np.ones((M, N), dtype=np_dtype)

a_num_bytes = a_h.size * a_h.itemsize
b_num_bytes = b_h.size * b_h.itemsize
@@ -141,22 +150,34 @@ def gpu_module():
hip_check(hip.hipMemcpy(b_d, b_h, b_num_bytes, hip.hipMemcpyKind.hipMemcpyHostToDevice))
hip_check(hip.hipMemcpy(c_d, c_h, c_num_bytes, hip.hipMemcpyKind.hipMemcpyHostToDevice))

-gridX = 32
-gridY = 32
-gridZ = 1
-warp_size = 32
-num_warps = 1
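+# One block per output tile: 512 // 16 = 32 blocks along x and y, 16x16 threads
+# per block, and 2 * 16 * 16 * 2 bytes = 1024 bytes of dynamic LDS for As and Bs.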
+(
+    (
+        blocks_per_grid_x,
+        blocks_per_grid_y,
+        blocks_per_grid_z,
+    ),
+    (
+        threads_per_block_x,
+        threads_per_block_y,
+        threads_per_block_z,
+    ),
+    shared_memory,
+) = (
+    (N // TILE_SIZE, N // TILE_SIZE, 1),
+    (TILE_SIZE, TILE_SIZE, 1),
+    2 * TILE_SIZE * TILE_SIZE * dtype.width // 8,
+)
+
stream = 0
-shared_memory = 0

launch_kernel(
    function.as_c_void_p(),
-    gridX,
-    gridY,
-    gridZ,
-    warp_size,
-    num_warps,
-    1,
+    blocks_per_grid_x,
+    blocks_per_grid_y,
+    blocks_per_grid_z,
+    threads_per_block_x,
+    threads_per_block_y,
+    threads_per_block_z,
    stream,
    shared_memory,
    a_d,
@@ -169,11 +190,13 @@ def gpu_module():
assert not np.allclose(correct, c_h)
hip_check(hip.hipMemcpy(c_h, c_d, c_num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost))

-# if not np.allclose(c_h, correct):
-#     with np.printoptions(threshold=np.inf, linewidth=200):
-#         print(correct)
-#         print(c_h)
-#         assert False
+
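+# f16 WMMA accumulation need not match NumPy bit-for-bit, so on mismatch report
+# the worst absolute and relative deviations instead of dumping whole matrices.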
+if not np.allclose(c_h, correct):
+    with np.printoptions(threshold=np.inf, linewidth=np.inf):
+        # print("correct", correct)
+        # print("c_h", c_h)
+        print("off by atol", np.max(np.abs(correct - c_h)))
+        print("off by rtol", np.max(np.abs(correct - c_h) / correct))

hip_check(hip.hipFree(a_d))
hip_check(hip.hipFree(b_d))