
Commit 9d694b1

Fix gemm tuner error on MI350 (ROCm#1313)

* workaround: retry tuning when encountering an invalid pointer
* fix lint error
* Update gemm_tuner.py em timeout

1 parent da9ce41 · commit 9d694b1
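The heart of the workaround is visible in the gemm_tuner.py diff below: run the tuner in a spawned child process so that a HIP-level abort (such as the invalid-pointer error seen on MI350) kills only the child, then back off and retry. A minimal, standalone sketch of that pattern, with a stand-in work() in place of runGemmTuner():

import multiprocessing as mp
import time


def work():
    # stand-in for runGemmTuner(); a hard crash here (even a native
    # abort) only takes down this child, not the driving script
    pass


if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)  # child gets fresh GPU state
    for attempt in range(3):
        p = mp.Process(target=work)
        p.start()
        p.join()
        if p.exitcode == 0:
            break
        time.sleep(0.5 * attempt)  # linear backoff before retrying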

2 files changed: +98 −16 lines

gradlib/gradlib/GemmTuner.py

Lines changed: 42 additions & 10 deletions
@@ -148,6 +148,7 @@ def __init__(
         profile_file="",
         # splitK=None,
     ):
+        torch.cuda.empty_cache()
         self.m = m
         self.k = k
         self.n = n
@@ -166,19 +167,19 @@ def __init__(
         self.rocb_sols = []
         self.rtol = 1e-2
         self.atol = 1e-2
-        self.ref = self.get_gemm_ref()
+        # self.ref = self.get_gemm_ref()
         self.check_err_ratio = err_ratio
         self.splitK = None
         self.profile_file = profile_file
-        self.start = torch.cuda.Event(enable_timing=True)
-        self.end = torch.cuda.Event(enable_timing=True)
+        # self.start = torch.cuda.Event(enable_timing=True)
+        # self.end = torch.cuda.Event(enable_timing=True)
         # prefer hipblaslt unless rocblas time is less than this
         # ratio of hipblaslt time
         self.hipb_prefer_ratio = 0.995
         self.rocblas_decode = rocblas_decode
         self.mp = mp
-        self.inbpe = self.inp.element_size()
-        self.outbpe = self.ref.element_size()
+        # self.inbpe = self.inp.element_size()
+        # self.outbpe = self.ref.element_size()
         self.asm_map = {}

     def find_hipblas_sols(self):
@@ -379,10 +380,15 @@ def hipb_time_all_sols(self, fast_mode=0, top_sols=0):
         if fast_mode == 1:
             self.hipb_gtimedf = self.save_topn_result(ret, fast_mode, "hipblaslt")
             return []
+        print(f">>> hipblaslt top solutions, Fast Mode {fast_mode}")
         return ret

     def save_topn_result(self, rets, fast_mode, libtype):
         results = []
+        if not rets:
+            return pd.DataFrame(
+                columns=["solidx", "gtimems", "splitK", "err_ratio", "kernelName"]
+            )
         for info, us, err_ratio in rets:
             res_one = []
             solidx = info[1]
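The new empty-result guard in save_topn_result keeps downstream concatenation and sorting from choking when a library returns no valid solutions. For illustration only (toy data, not the tuner's actual frames):

import pandas as pd

cols = ["solidx", "gtimems", "splitK", "err_ratio", "kernelName"]
empty = pd.DataFrame(columns=cols)  # what the guard returns for no results
found = pd.DataFrame([[7, 0.12, 1, 0.0, "k0"]], columns=cols)

# concat and sort behave the same whether or not either frame has rows
merged = pd.concat([empty, found], ignore_index=True)
print(merged.sort_values("gtimems"))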
@@ -478,8 +484,11 @@ def rocb_time_all_sols(self, fast_mode=0, top_sols=0):
                     self.atol,
                 )
             )
-        in_data = [(len(solutions), ())]
-        ret = mp_tuner(task, in_data, self.mp, fast_mode == 1)
+        if task:
+            in_data = [(len(solutions), ())]
+            ret = mp_tuner(task, in_data, self.mp, fast_mode == 1)
+        else:
+            ret = []
         if fast_mode == 1:
             self.rocb_gtimedf = self.save_topn_result(ret, fast_mode, "rocblas")
             return []
@@ -519,6 +528,17 @@ def run_solutions(self):
         rets = self.run_best_solutions()
         return rets

+    def cleanup(self):
+        if hasattr(self, "inp"):
+            del self.inp
+        if hasattr(self, "weights"):
+            del self.weights
+        if hasattr(self, "bias") and self.bias is not None:
+            del self.bias
+        if hasattr(self, "blob"):
+            cpu_blob = self.blob.cpu()
+            del cpu_blob
+

 class GemmTuner(GemmCommonTuner):
     ARG_DEFAULTS = {
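cleanup() exists so each Gemm instance drops its tensor references before the tuner moves on to the next shape; PyTorch's caching allocator can only reuse or release memory once the last reference is gone. A small illustration of that behavior (assumes a visible ROCm/CUDA device):

import torch

if torch.cuda.is_available():
    t = torch.empty(1024, 1024, device="cuda")
    before = torch.cuda.memory_allocated()
    del t                     # drop the last reference to the tensor
    torch.cuda.empty_cache()  # hand cached blocks back to the driver
    after = torch.cuda.memory_allocated()
    print(f"allocated bytes: {before} -> {after}")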
@@ -597,6 +628,7 @@ def __init__(

         self.hipb_prefer_ratio = 0.995
         self.cu_num = self.get_cu_num()
+        self.gemmobj = None

     def calculate_perf(
         self,
@@ -708,7 +740,6 @@ def tune(self, untunedf, tunedf, args):
             ds = df.loc[i, :]
             indtype = ds["dtype"]
             outdtype = ds["outdtype"]
-
             gemmobj = Gemm(
                 ds["M"],
                 ds["N"],
@@ -722,9 +753,11 @@
                 err_ratio=args.errRatio,
                 profile_file=args.profile_file,
             )
+
             ret.extend(gemmobj.run_solutions())
+            gemmobj.cleanup()
             del gemmobj
-            torch.cuda.empty_cache()
+
         return ret

     def processResult(self, rets, fast_mode):
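With cleanup() in place, the tune() loop follows an explicit allocate/run/cleanup/del lifecycle per GEMM shape rather than leaning on a trailing empty_cache(). A toy version of the same lifecycle, using a dummy class rather than the tuner's Gemm:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"


class Work:
    def __init__(self, n):
        self.buf = torch.empty(n, device=device)  # per-shape buffer

    def run(self):
        return self.buf.numel()

    def cleanup(self):
        if hasattr(self, "buf"):
            del self.buf  # release the tensor before the object dies


results = []
for n in (1 << 10, 1 << 12):
    w = Work(n)
    results.append(w.run())
    w.cleanup()  # free GPU memory eagerly, shape by shape
    del w
print(results)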
@@ -819,7 +852,6 @@ def post_process(self, rets, args, topk=-1, fast_mode=False):
             if best_gtimedfs.empty:
                 best_gtimedfs = resultdf1
             else:
-                print("concat ", resultdf1)
                 best_gtimedfs = pd.concat([best_gtimedfs, resultdf1], ignore_index=True)

         print(f"{key} >>> Fastest Solution is \n {resultdf1}", flush=True)

gradlib/gradlib/gemm_tuner.py

Lines changed: 56 additions & 6 deletions
@@ -28,6 +28,8 @@
 from GemmTuner import GemmTuner

 import time
+import multiprocessing as mp
+import gc

 aiter.rocb_create_extension()
 aiter.hipb_create_extension()
@@ -89,7 +91,7 @@ def load_input_gemms(input_file):
     return


-if __name__ == "__main__":
+def runGemmTuner():
     gtuner = GemmTuner()
     ext_group = gtuner.parser.add_argument_group("extra parameters")
     ext_group.add_argument(
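Hoisting the old __main__ body into runGemmTuner() is what makes it usable as a Process target: under the spawn start method the child re-imports the module and looks the target up by name, so it must be a module-level function. A tiny demonstration of that constraint:

import multiprocessing as mp


def task():  # must live at module level to be importable in the spawned child
    print("runs in a fresh interpreter under spawn")


if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)
    p = mp.Process(target=task)
    p.start()
    p.join()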
@@ -117,7 +119,6 @@ def load_input_gemms(input_file):
         help="Tensor parallelism to be used.",
     )
     args = gtuner.parse_args()
-
     if args.outdtype is None:
         args.outdtype = args.indtype
     indtype = get_dtype(args.indtype)
@@ -130,9 +131,7 @@
         print(">>> Warning! NO MODEL SPECIFIED. Tuning for LL2 13B TP1")
         # LL2 13B sizes
         mksets = [(15360, 5120), (5120, 5120), (27648, 5120), (5120, 13824)]
-
         gtuner.add_gemm(m=32000, n=1, k=5120, indtype=indtype)  # logits gemm
-
     else:
         mksets, hidden_size, dtype = generate_mk_sets(args.model_dir, args.tp)
         gtuner.add_gemm(
@@ -141,11 +140,63 @@
             k=hidden_size,
             indtype=dtype,
         )  # TODO: Handle cases where vocab_size is not divisible by tp
-
     for n in sorted(nsets):
         for m, k in mksets:
             gtuner.add_gemm(m, n, k, indtype=dtype)
     gtuner.untunedf.to_csv("./tmp_untuned.csv", index=False)
     args.untune_file = "./tmp_untuned.csv"
-
     gtuner.run(args)
+
+
+def clean():
+    gc.collect()
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+        if hasattr(torch.cuda, "memory_allocated"):
+            torch.cuda.synchronize()
+    try:
+        if hasattr(mp, "resource_tracker"):
+            mp.resource_tracker.ensure_running()
+            # clean leaked semaphore objects
+            if hasattr(mp.resource_tracker, "_CLEANUP_FUNCS"):
+                # be careful: _CLEANUP_FUNCS is a private CPython internal
+                for name in list(mp.resource_tracker._CLEANUP_FUNCS.keys()):
+                    try:
+                        mp.resource_tracker._CLEANUP_FUNCS.pop(name)()
+                    except Exception:
+                        pass
+    except Exception as e:
+        print(f"Resource cleanup warning: {e}")
+
+
+if __name__ == "__main__":
+    retries = 0
+    MAX_TRY = 30
+    mp.set_start_method("spawn", force=True)
+    process = None  # so the finally block is safe if Process() raises first
+    while retries <= MAX_TRY:
+        try:
+            process = mp.Process(target=runGemmTuner, args=(), daemon=False)
+            process.start()
+            process.join()
+            if process.exitcode != 0:
+                time.sleep(0.5 * retries)
+                print(
+                    "!Error running GemmTuner, process exitcode is ", process.exitcode
+                )
+                clean()
+                retries += 1
+            else:
+                break
+        except Exception as e:
+            print(f"Process creation failed: {e}")
+            retries += 1
+            clean()
+            time.sleep(1)
+        finally:
+            if process and process.is_alive():
+                process.terminate()
+                process.join(timeout=5)
+
+    clean()
+    print(f"number of retries: {retries}")
