@@ -58,7 +58,7 @@ def gpu_module():
 set_container_module(ctx.module)
 
 v_len = 16
-M, K, N = 512, 512, 512
+M, K, N = 16, 16, 16
 TILE_SIZE = BK = 16
 dtype = T.f16()
 np_dtype = np.float16
@@ -78,24 +78,26 @@ def kernel(
 
     row = block_idx.y * TILE_SIZE + thread_idx.y
     col = block_idx.x * TILE_SIZE + thread_idx.x
+    # gpu.printf("(%ld, %ld)\n", row, col)
+    # vector.print_(source=row)
 
     sum = arith.constant(np.full([v_len], 0.0, np_dtype), v16)
     for t, sum, _ in scf.range_(0, N, BK, iter_args=[sum]):
-        Bs[thread_idx.y, thread_idx.x] = B[thread_idx.y + t, col]
+        Bs[thread_idx.y, thread_idx.x] = B[col, thread_idx.y + t]
         As[thread_idx.y, thread_idx.x] = A[row, thread_idx.x + t]
 
         gpu.barrier()
 
-        a_frag = As @ vector.load(v16) @ [thread_idx.y, 0]
-        b_frag = Bs @ vector.load(v16) @ [0, thread_idx.x]
+        lane = thread_idx.x % v_len
+        a_frag = As @ vector.load(v16) @ [lane, 0]
+        b_frag = Bs @ vector.load(v16) @ [lane, 0]
+
+        # call the WMMA intrinsic
         false = arith.constant(False, T.bool())
         sum = rocdl.wmma_f16_16x16x16_f16(v16, [a_frag, b_frag, sum, false])
-
-        gpu.barrier()
-
         sum = yield sum
 
-    C[row, col] = sum
+    C[row, col] = sum[2 * (row // 2)]
 
 
 props = hip.hipDeviceProp_t()
@@ -110,13 +112,21 @@ def gpu_module():
 
 ip.__exit__(None, None, None)
 
+# gpu_module = run_pipeline(gpu_module, Pipeline().cse())
+# print(gpu_module)
+
 O = 3
 output_format = "binary"
 
 lowered_module = run_pipeline(
     gpu_module,
     Pipeline()
-    .Gpu(Pipeline().convert_gpu_to_rocdl(use_bare_ptr_memref_call_conv=True))
+    .Gpu(
+        Pipeline().convert_gpu_to_rocdl(
+            use_bare_ptr_memref_call_conv=True,
+            runtime="HIP",
+        )
+    )
     .rocdl_attach_target(chip=arch, abi="500", O=O)
     .gpu_to_llvm()
     .lower_to_llvm()
@@ -132,12 +142,20 @@ def gpu_module():
 hip_module = hip_check(hip.hipModuleLoadData(hsaco))
 function = hip_check(hip.hipModuleGetFunction(hip_module, kernel.__name__.encode()))
 
-# a_h = np.random.randint(0, 10, (M, K)).astype(dtype=np_dtype)
-# b_h = np.random.randint(0, 10, (K, N)).astype(dtype=np_dtype)
-a_h = np.ones((M, K)).astype(dtype=np_dtype)
-b_h = np.ones((K, N)).astype(dtype=np_dtype)
-c_h = -3 * np.ones((M, N), dtype=np_dtype)
+a_h = np.random.randint(0, 10, (M, K)).astype(dtype=np_dtype)
+b_h = np.random.randint(0, 10, (K, N)).astype(dtype=np_dtype)
+# a_h = np.ones((M, K)).astype(dtype=np_dtype)
+# b_h = np.ones((K, N)).astype(dtype=np_dtype)
+c_h = 0 * np.ones((M, N), dtype=np_dtype)
+
+for k in range(K):
+    a = a_h[:, k]
+    b = b_h[k, :]
+    c_h += np.outer(a, b)
+
+assert np.allclose(a_h @ b_h, c_h)
 
+c_h = -3 * np.ones((M, N), dtype=np_dtype)
 a_num_bytes = a_h.size * a_h.itemsize
 b_num_bytes = b_h.size * b_h.itemsize
 c_num_bytes = c_h.size * c_h.itemsize
@@ -190,14 +208,14 @@ def gpu_module():
 assert not np.allclose(correct, c_h)
 hip_check(hip.hipMemcpy(c_h, c_d, c_num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost))
 
-
 if not np.allclose(c_h, correct):
     with np.printoptions(threshold=np.inf, linewidth=np.inf):
-        # print("correct", correct)
-        # print("c_h", c_h)
+        print("correct\n", correct)
+        print("c_h\n", c_h)
         print("off by atol", np.max(np.abs(correct - c_h)))
         print("off by rtol", np.max(np.abs(correct - c_h) / correct))
 
+
 hip_check(hip.hipFree(a_d))
 hip_check(hip.hipFree(b_d))
 hip_check(hip.hipFree(c_d))
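Note (not part of the diff above): the kernel's scf.range_ loop walks the reduction dimension one BK-wide slab at a time, in the same spirit as the rank-1 (np.outer) reference check this commit adds on the host. Below is a minimal host-side NumPy sketch of that BK-tiled accumulation, assuming the same M = K = N = 16 and BK = 16 sizes used in the commit; the variable names are illustrative only.

import numpy as np

# Hypothetical host-side mirror of the kernel's k-loop tiling.
M = K = N = 16
BK = 16
a = np.random.randint(0, 10, (M, K)).astype(np.float16)
b = np.random.randint(0, 10, (K, N)).astype(np.float16)
c = np.zeros((M, N), dtype=np.float16)
for t in range(0, K, BK):
    # accumulate one BK-wide slab, analogous to one scf.range_ iteration
    c += a[:, t : t + BK] @ b[t : t + BK, :]
assert np.allclose(c, a @ b)

With K = BK = 16 the loop body runs exactly once, which lines up with the single rocdl.wmma_f16_16x16x16_f16 call each thread makes per output tile in the changed kernel.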