- import ctypes
import platform
+ import random
import sys
import tempfile
+ import time
from textwrap import dedent

import mlir.extras.types as T
# noinspection PyUnresolvedReferences
from mlir.extras.testing import mlir_ctx as ctx, filecheck, MLIRContext
- from util import hip_bindings_not_installed, hip_check, launch_kernel
+ from util import hip_bindings_not_installed, hip_check, launch_kernel, hip_synchronize

# needed since the fixture isn't defined here nor in conftest.py
pytest.mark.usefixtures("ctx")
@@ -962,6 +963,7 @@ def test_amdgpu_vector(ctx: MLIRContext):
    scale = 2
    M, K, N = 2 * scale, 4 * scale, 6 * scale
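+     # tile sizes: compute a tz_a x tz_b tile of C per step, advancing k by tz_c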
+     tz_a, tz_b, tz_c = [2, 2, 2]
    v2f32 = T.vector(2, T.f32())

    @gpu_func
@@ -972,11 +974,11 @@ def smol_matmul(
    ):
        cst = arith.constant(np.full([4], 0.0, np.float32), T.vector(4, T.f32()))
        cst_0 = arith.constant(
-             np.full([2, 2], 0.0, np.float32), T.vector(2, 2, T.f32())
+             np.full([tz_a, tz_b], 0.0, np.float32), T.vector(tz_a, tz_b, T.f32())
        )
-         for i, C, v0 in scf.range_(0, M, 2, iter_args=[C]):
-             for j, C, v1 in scf.range_(0, N, 2, iter_args=[C]):
-                 for k, C, v2 in scf.range_(0, K, 2, iter_args=[C]):
+         for i, C, v0 in scf.range_(0, M, tz_a, iter_args=[C]):
+             for j, C, v1 in scf.range_(0, N, tz_b, iter_args=[C]):
+                 for k, C, v2 in scf.range_(0, K, tz_c, iter_args=[C]):
                    cst[0::1] = A @ load(v2f32) @ [i, k]
                    cst[2::1] = A @ load(v2f32) @ [i + 1, k]
                    cst_0[0] = C @ load(v2f32) @ [i, j]
@@ -1078,3 +1080,116 @@ def gpu_module():
    hip_check(hip.hipFree(c_d))

    hip_check(hip.hipModuleUnload(hip_module))
+
+
+ @pytest.mark.skipif(hip_bindings_not_installed(), reason="hip not installed")
+ def test_amdgpu_bank_conflicts(ctx: MLIRContext):
+     from hip import hip
+
+     set_container_module(ctx.module)
+
+     M = 1024
+
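+     # elementwise square of A into B with row-contiguous writes: for a fixed i,
+     # consecutive threads write consecutive addresses of row i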
+     @gpu_func
+     def no_bank_conflicts(A: T.memref(M, M, T.f32()), B: T.memref(M, M, T.f32())):
+         for i in range(M):
+             a = A[i, thread_idx.x]
+             B[i, thread_idx.x] = a * a
+
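+     # same computation with column-strided writes: consecutive threads write
+     # addresses a full row (M floats) apart, defeating coalescing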
+     @gpu_func
+     def all_bank_conflicts(A: T.memref(M, M, T.f32()), B: T.memref(M, M, T.f32())):
+         for i in range(M):
+             a = A[i, thread_idx.x]
+             B[thread_idx.x, i] = a * a
+
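+     # query the device's gfx arch (e.g. "gfx90a") so codegen targets the right chip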
+     props = hip.hipDeviceProp_t()
+     hip_check(hip.hipGetDeviceProperties(props, 0))
+     arch = props.gcnArchName.decode()
+
+     @module("naive", [f'#rocdl.target<chip = "{arch}">'])
+     def gpu_module():
+         no_bank_conflicts.emit()
+         all_bank_conflicts.emit()
+
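+     # lower the GPU module through ROCDL/LLVM and serialize it to a hsaco binary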
+     lowered_module = run_pipeline(
+         gpu_module,
+         Pipeline()
+         .Gpu(Pipeline().convert_gpu_to_rocdl(use_bare_ptr_memref_call_conv=True))
+         .rocdl_attach_target(chip=arch)
+         .gpu_to_llvm()
+         .lower_to_llvm()
+         .gpu_module_to_binary(),
+     )
+
+     hsaco = get_compile_object_bytes(lowered_module)
+     hip_module = hip_check(hip.hipModuleLoadData(hsaco))
+
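+     # host buffers: every row of A is 0..M-1; B starts at zero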
+     a_h = np.arange(M).astype(dtype=np.float32)
+     a_h = np.tile(a_h, (M, 1))
+     b_h = np.zeros((M, M), dtype=np.float32)
+
+     a_num_bytes = a_h.size * a_h.itemsize
+     b_num_bytes = b_h.size * b_h.itemsize
+
+     a_d = hip_check(hip.hipMalloc(a_num_bytes))
+     b_d = hip_check(hip.hipMalloc(b_num_bytes))
+
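+     # launch config: blocks of warp_size x num_warps threads on a gridX x gridY grid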
+     gridX = max(M // 32, 1)
+     gridY = max(M // 8, 1)
+     gridZ = 1
+     warp_size = 32
+     num_warps = 8
+     stream = 0
+     shared_memory = 0
+
+     times = {
+         no_bank_conflicts.__name__: 0,
+         all_bank_conflicts.__name__: 0,
+     }
+     runs = 10
+     for i in range(runs):
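+         # shuffle execution order each run to avoid systematic ordering bias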
+         kernels = [no_bank_conflicts, all_bank_conflicts]
+         random.shuffle(kernels)
+         for kernel in kernels:
+             hip_check(
+                 hip.hipMemcpy(
+                     a_d, a_h, a_num_bytes, hip.hipMemcpyKind.hipMemcpyHostToDevice
+                 )
+             )
+             hip_check(
+                 hip.hipMemcpy(
+                     b_d, b_h, b_num_bytes, hip.hipMemcpyKind.hipMemcpyHostToDevice
+                 )
+             )
+             function = hip_check(
+                 hip.hipModuleGetFunction(hip_module, kernel.__name__.encode())
+             )
+
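+             # wall-clock one launch plus a full device sync; run 0 is warmup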
+             start = time.monotonic()
+             launch_kernel(
+                 function.as_c_void_p(),
+                 gridX,
+                 gridY,
+                 gridZ,
+                 warp_size,
+                 num_warps,
+                 stream,
+                 shared_memory,
+                 a_d,
+                 b_d,
+             )
+             hip_synchronize()
+             if i > 0:
+                 times[kernel.__name__] += time.monotonic() - start
+
+             hip_check(
+                 hip.hipMemcpy(
+                     b_h, b_d, b_num_bytes, hip.hipMemcpyKind.hipMemcpyDeviceToHost
+                 )
+             )
+
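+     # run 0 was excluded above, so average over the runs - 1 timed iterations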
+     times[no_bank_conflicts.__name__] /= runs - 1
+     times[all_bank_conflicts.__name__] /= runs - 1
+     for k, v in times.items():
+         print(f"{k}: {v * 1e3:.3e} ms")