@@ -148,6 +148,7 @@ def __init__(
148148 profile_file = "" ,
149149 # splitK=None,
150150 ):
151+ torch .cuda .empty_cache ()
151152 self .m = m
152153 self .k = k
153154 self .n = n
@@ -166,19 +167,19 @@ def __init__(
166167 self .rocb_sols = []
167168 self .rtol = 1e-2
168169 self .atol = 1e-2
169- self .ref = self .get_gemm_ref ()
170+ # self.ref = self.get_gemm_ref()
170171 self .check_err_ratio = err_ratio
171172 self .splitK = None
172173 self .profile_file = profile_file
173- self .start = torch .cuda .Event (enable_timing = True )
174- self .end = torch .cuda .Event (enable_timing = True )
174+ # self.start = torch.cuda.Event(enable_timing=True)
175+ # self.end = torch.cuda.Event(enable_timing=True)
175176 # prefer hipblaslt unless rocblas time is less than this
176177 # ratio of hipblaslt time
177178 self .hipb_prefer_ratio = 0.995
178179 self .rocblas_decode = rocblas_decode
179180 self .mp = mp
180- self .inbpe = self .inp .element_size ()
181- self .outbpe = self .ref .element_size ()
181+ # self.inbpe = self.inp.element_size()
182+ # self.outbpe = self.ref.element_size()
182183 self .asm_map = {}
183184
184185 def find_hipblas_sols (self ):
@@ -379,10 +380,15 @@ def hipb_time_all_sols(self, fast_mode=0, top_sols=0):
379380 if fast_mode == 1 :
380381 self .hipb_gtimedf = self .save_topn_result (ret , fast_mode , "hipblaslt" )
381382 return []
383+ print (f">>> hipblaslt top solutions, Fast Mode { fast_mode } " )
382384 return ret
383385
384386 def save_topn_result (self , rets , fast_mode , libtype ):
385387 results = []
388+ if not rets :
389+ return pd .DataFrame (
390+ columns = ["solidx" , "gtimems" , "splitK" , "err_ratio" , "kernelName" ]
391+ )
386392 for info , us , err_ratio in rets :
387393 res_one = []
388394 solidx = info [1 ]
@@ -478,8 +484,11 @@ def rocb_time_all_sols(self, fast_mode=0, top_sols=0):
478484 self .atol ,
479485 )
480486 )
481- in_data = [(len (solutions ), ())]
482- ret = mp_tuner (task , in_data , self .mp , fast_mode == 1 )
487+ if task :
488+ in_data = [(len (solutions ), ())]
489+ ret = mp_tuner (task , in_data , self .mp , fast_mode == 1 )
490+ else :
491+ ret = []
483492 if fast_mode == 1 :
484493 self .rocb_gtimedf = self .save_topn_result (ret , fast_mode , "rocblas" )
485494 return []
@@ -519,6 +528,28 @@ def run_solutions(self):
519528 rets = self .run_best_solutions ()
520529 return rets
521530
531+ def cleanup (self ):
532+ if hasattr (self , "inp" ):
533+ del self .inp
534+ if hasattr (self , "weights" ):
535+ del self .weights
536+ if hasattr (self , "bias" ) and self .bias is not None :
537+ del self .bias
538+ if hasattr (self , "blob" ):
539+ cpu_blob = self .blob .cpu ()
540+ del cpu_blob
541+
542+ def cleanup (self ):
543+ if hasattr (self , "inp" ):
544+ del self .inp
545+ if hasattr (self , "weights" ):
546+ del self .weights
547+ if hasattr (self , "bias" ) and self .bias is not None :
548+ del self .bias
549+ if hasattr (self , "blob" ):
550+ cpu_blob = self .blob .cpu ()
551+ del cpu_blob
552+
522553
523554class GemmTuner (GemmCommonTuner ):
524555 ARG_DEFAULTS = {
@@ -597,6 +628,7 @@ def __init__(
597628
598629 self .hipb_prefer_ratio = 0.995
599630 self .cu_num = self .get_cu_num ()
631+ self .gemmobj = None
600632
601633 def calculate_perf (
602634 self ,
@@ -708,7 +740,6 @@ def tune(self, untunedf, tunedf, args):
708740 ds = df .loc [i , :]
709741 indtype = ds ["dtype" ]
710742 outdtype = ds ["outdtype" ]
711-
712743 gemmobj = Gemm (
713744 ds ["M" ],
714745 ds ["N" ],
@@ -722,9 +753,11 @@ def tune(self, untunedf, tunedf, args):
722753 err_ratio = args .errRatio ,
723754 profile_file = args .profile_file ,
724755 )
756+
725757 ret .extend (gemmobj .run_solutions ())
758+ gemmobj .cleanup ()
726759 del gemmobj
727- torch . cuda . empty_cache ()
760+
728761 return ret
729762
730763 def processResult (self , rets , fast_mode ):
@@ -819,7 +852,6 @@ def post_process(self, rets, args, topk=-1, fast_mode=False):
819852 if best_gtimedfs .empty :
820853 best_gtimedfs = resultdf1
821854 else :
822- print ("concat " , resultdf1 )
823855 best_gtimedfs = pd .concat ([best_gtimedfs , resultdf1 ], ignore_index = True )
824856
825857 print (f"{ key } >>> Fastest Solution is \n { resultdf1 } " , flush = True )
0 commit comments