@@ -58,7 +58,7 @@ def gpu_module():
     set_container_module(ctx.module)
 
     v_len = 16
-    M, K, N = 16, 16, 16
+    M, K, N = 512, 512, 512
     TILE_SIZE = BK = 16
     dtype = T.f16()
     np_dtype = np.float16
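For orientation (editorial note, not part of the commit): the problem size grows from a single 16x16x16 tile to a full 512x512x512 matmul, so each block now computes one 16x16 output tile and must iterate over K. A back-of-the-envelope sketch; the launch configuration is outside this hunk, so `grid`/`block`/`k_steps` are illustrative names, not the script's:

TILE_SIZE = BK = 16
M, K, N = 512, 512, 512
grid = (N // TILE_SIZE, M // TILE_SIZE)  # 32 x 32 blocks, one 16x16 output tile each
block = (TILE_SIZE, TILE_SIZE)           # 256 threads per block
k_steps = K // BK                        # each block now makes 32 trips over K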
@@ -78,23 +78,27 @@ def kernel(
 
         row = block_idx.y * TILE_SIZE + thread_idx.y
         col = block_idx.x * TILE_SIZE + thread_idx.x
+        lane = thread_idx.x % v_len
         # gpu.printf("(%ld, %ld)\n", row, col)
         # vector.print_(source=row)
 
         sum = arith.constant(np.full([v_len], 0.0, np_dtype), v16)
-        for t, sum, _ in scf.range_(0, N, BK, iter_args=[sum]):
-            Bs[thread_idx.y, thread_idx.x] = B[col, thread_idx.y + t]
-            As[thread_idx.y, thread_idx.x] = A[row, thread_idx.x + t]
 
+        Bs[thread_idx.y, thread_idx.x] = B[col, thread_idx.y + 0]
+        As[thread_idx.y, thread_idx.x] = A[row, thread_idx.x + 0]
+
+        for t, sum, _ in scf.range_(BK, N + BK, BK, iter_args=[sum]):
             gpu.barrier()
 
-            lane = thread_idx.x % v_len
             a_frag = As @ vector.load(v16) @ [lane, 0]
             b_frag = Bs @ vector.load(v16) @ [lane, 0]
 
-            # call the WMMA intrinsic
-            false = arith.constant(False, T.bool())
-            sum = rocdl.wmma_f16_16x16x16_f16(v16, [a_frag, b_frag, sum, false])
+            sum = rocdl.wmma_f16_16x16x16_f16(a_frag, b_frag, sum)
+
+            if arith.index_cast(t, T.i32()) < N:
+                Bs[thread_idx.y, thread_idx.x] = B[col, thread_idx.y + t]
+                As[thread_idx.y, thread_idx.x] = A[row, thread_idx.x + t]
+
             sum = yield sum
 
         C[row, col] = sum[2 * (row // 2)]
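This hunk restructures the K-loop into a simple software pipeline: the first As/Bs tiles are loaded before the loop (the `+ 0` offsets), each iteration issues the WMMA on the previously loaded tiles and then prefetches the next pair, and the `arith.index_cast(t, T.i32()) < N` guard skips the prefetch on the final trip. The bounds shift from `(0, N)` to `(BK, N + BK)` so the trip count is unchanged, and the loop-invariant `lane` computation is hoisted above the loop. The `wmma_f16_16x16x16_f16` call site also drops the explicit result type and the trailing boolean operand, presumably now supplied by the wrapper. Below is a plain-NumPy sketch of the same loop structure (editorial illustration only; the function name and float32 accumulation are mine, and it ignores shared memory and the transposed Bs layout):

import numpy as np

def pipelined_matmul(A, B, BK=16):
    # Structure-only sketch of the pipelined loop introduced above.
    M, K = A.shape
    _, N = B.shape
    C = np.zeros((M, N), dtype=np.float32)
    # Preload tile 0 before the loop (mirrors the `+ 0` loads).
    As = A[:, 0:BK].astype(np.float32)
    Bs = B[0:BK, :].astype(np.float32)
    for t in range(BK, K + BK, BK):  # mirrors scf.range_(BK, N + BK, BK)
        C += As @ Bs                 # compute on the previously loaded tiles
        if t < K:                    # mirrors the index_cast(t) < N guard
            As = A[:, t:t + BK].astype(np.float32)
            Bs = B[t:t + BK, :].astype(np.float32)
    return C

# Sanity check of the loop structure:
A = np.random.rand(64, 64).astype(np.float16)
B = np.random.rand(64, 64).astype(np.float16)
assert np.allclose(pipelined_matmul(A, B), A.astype(np.float32) @ B.astype(np.float32))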
@@ -142,18 +146,25 @@ def gpu_module():
     hip_module = hip_check(hip.hipModuleLoadData(hsaco))
     function = hip_check(hip.hipModuleGetFunction(hip_module, kernel.__name__.encode()))
 
-    a_h = np.random.randint(0, 10, (M, K)).astype(dtype=np_dtype)
-    b_h = np.random.randint(0, 10, (K, N)).astype(dtype=np_dtype)
-    # a_h = np.ones((M, K)).astype(dtype=np_dtype)
-    # b_h = np.ones((K, N)).astype(dtype=np_dtype)
-    c_h = 0 * np.ones((M, N), dtype=np_dtype)
+    # a_h = np.random.randint(1, 5, (M, K)).astype(dtype=np_dtype)
+    # b_h = np.random.randint(1, 5, (K, N)).astype(dtype=np_dtype)
 
+    # a_h = np.random.rand(M, K).astype(np_dtype)
+    # b_h = np.random.rand(K, N).astype(np_dtype)
+
+    a_h = 3 * np.ones((M, K)).astype(dtype=np_dtype)
+    a_h[0 : M // 2, 0 : K // 2] = 0
+    a_h[M // 2 : M, K // 2 : K] = 1
+    b_h = 2 * np.ones((K, N)).astype(dtype=np_dtype)
+    b_h[0 : K // 2, 0 : N // 2] = 2
+    b_h[K // 2 : K, N // 2 : N] = 3
+
+    c_h = 0 * np.ones((M, N), dtype=np.float32)
     for k in range(K):
-        a = a_h[:, k]
-        b = b_h[k, :]
+        a = a_h.astype(np.float32)[:, k]
+        b = b_h.astype(np.float32)[k, :]
         c_h += np.outer(a, b)
-
-    assert np.allclose(a_h @ b_h, c_h)
+    assert np.allclose(a_h.astype(np.float32) @ b_h.astype(np.float32), c_h)
 
     c_h = -3 * np.ones((M, N), dtype=np_dtype)
     a_num_bytes = a_h.size * a_h.itemsize
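The host-side reference now accumulates in float32 rather than f16. float16 represents integers exactly only up to 2048, so a 512-term running sum of products drifts (and can overflow past the f16 maximum of 65504), which would make the reference itself wrong at this problem size. A small self-contained demonstration (editorial, not from the commit):

import numpy as np

rng = np.random.default_rng(0)
a = rng.integers(0, 10, 512).astype(np.float16)
b = rng.integers(0, 10, 512).astype(np.float16)

acc16 = np.float16(0.0)
for x, y in zip(a, b):
    acc16 = np.float16(acc16 + x * y)  # running sum kept in half precision
acc32 = np.dot(a.astype(np.float32), b.astype(np.float32))  # exact for these small integers
print(acc16, acc32)  # the f16 sum drifts once it exceeds 2048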
@@ -210,10 +221,12 @@ def gpu_module():
 
     if not np.allclose(c_h, correct):
         with np.printoptions(threshold=np.inf, linewidth=np.inf):
-            print("correct\n", correct)
-            print("c_h\n", c_h)
+            # print("correct\n", correct)
+            # print("c_h\n", c_h)
             print("off by atol", np.max(np.abs(correct - c_h)))
             print("off by rtol", np.max(np.abs(correct - c_h) / correct))
+            print("num incorrect", np.sum(np.abs(correct - c_h) != 0))
+            print("fraction incorrect", np.sum(np.abs(correct - c_h) != 0) / (M * N))
 
 
    hip_check(hip.hipFree(a_d))
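With 512x512 outputs, full matrix dumps are no longer readable, so the commit comments them out in favor of aggregate error statistics. A hypothetical helper (not in this commit) for localizing mismatches, taking the script's `c_h` and `correct` arrays as inputs:

import numpy as np

def report_first_mismatch(c_h, correct):
    # Locate mismatching entries, e.g. to check whether errors
    # cluster in particular 16x16 tiles.
    bad = np.argwhere(~np.isclose(c_h, correct))
    if bad.size:
        r, c = bad[0]
        print(f"first mismatch at ({r}, {c}): got {c_h[r, c]}, expected {correct[r, c]}")
    print("num mismatched", len(bad))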