Skip to content

Commit 97ac929

Browse files
authored
refactor fmoe tuner profile (ROCm#1614)
* refactor fmoe profile data, and log
* Update fused_moe.py
1 parent d263d44 commit 97ac929

File tree

4 files changed

+33
-7
lines changed

4 files changed

+33
-7
lines changed

aiter/fused_moe.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -615,6 +615,9 @@ def MainFunc():
615615
)
616616

617617
def FinalFunc():
618+
logger.info(
619+
f"[Hint] tuned configs are saved in {tune_file}, you can set AITER_CONFIG_FMOE to this file to use tuned configs"
620+
)
618621
logger.info("\033[0m")
619622

620623
# cfg = cfg_2stages.get(keys, None)

aiter/utility/base_tuner.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def __init__(self, name, key, resultList, description=None):
5858
self.failed = pd.DataFrame(columns=self.columns)
5959

6060
self.remain_untuned = pd.DataFrame(columns=self.keys)
61+
self.sort_keys = key
6162
self.start_time = 0
6263

6364
def get_arg_defaults(self):
@@ -305,7 +306,14 @@ def post_process(self, rets, args, topk=-1, fast_mode=False):
305306
"""post process, post process all results to return topk results"""
306307
rets = list(rets)
307308
if args.profile_file != "":
309+
if args.verbose:
310+
logger.info(f"saving profile to {args.profile_file}")
308311
profiledf = self.result_to_df(sorted(rets, key=itemgetter(0)))
312+
if os.path.exists(args.profile_file):
313+
old_df = pd.read_csv(args.profile_file)
314+
else:
315+
old_df = pd.DataFrame(columns=self.columns)
316+
profiledf = pd.concat([old_df, profiledf], ignore_index=True)
309317
profiledf.to_csv(args.profile_file, index=False, na_rep="Null")
310318

311319
if fast_mode or topk == -1:
@@ -394,7 +402,7 @@ def run(self, args, fast_mode=False):
394402
logger.info(f"args: {args}")
395403
if len(self.untunedf) == 0:
396404
# self.update_tflops_bw(args.tune_file)
397-
self.sortResults(output_file, args.sort, self.keys)
405+
self.sortResults(output_file, args.sort, self.sort_keys)
398406
logger.info(
399407
f"no shapes to be tuned, skip tuning, tuned file is {args.tune_file}"
400408
)
@@ -426,7 +434,7 @@ def run(self, args, fast_mode=False):
426434
logger.info(
427435
f"tune result is none or all shape is tuned in {args.tune_file}!"
428436
)
429-
self.sortResults(output_file, args.sort, self.keys)
437+
self.sortResults(output_file, args.sort, self.sort_keys)
430438
except KeyboardInterrupt:
431439
tuning_status = "Interrupted"
432440
logger.error(
@@ -460,6 +468,14 @@ def __init__(
460468
description=None,
461469
):
462470
super().__init__(name, key, resultList, description)
471+
# Swap M and N positions to ensure N comes before M
472+
self.sort_keys = list(key)
473+
m_idx = self.sort_keys.index("M")
474+
n_idx = self.sort_keys.index("N")
475+
self.sort_keys[m_idx], self.sort_keys[n_idx] = (
476+
self.sort_keys[n_idx],
477+
self.sort_keys[m_idx],
478+
)
463479

464480
def pre_process(self, args):
465481
if args.all:

csrc/cktile_gemm_a8w8_bpreshuffle/gen_instances.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ def get_tune_dict(tune_dict_csv):
240240
N = tune_df.loc[i, "N"]
241241
K = tune_df.loc[i, "K"]
242242
kid = tune_df.loc[i, "kernelId"]
243-
if kid < 0 or kid > len(kernels_list):
243+
if kid < 0 or kid >= len(kernels_list):
244244
print(f"[Warning]: kernelId {kid} is out of range, skip it")
245245
continue
246246
tune_dict[(M, N, K)] = kernels_list[kid]

hsa/gfx942/fmoe_2stages/tune.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ class FmoeTuner(TunerCommon):
4949
"untune_file": "aiter/configs/untuned_fmoe.csv",
5050
"errRatio": 0.5,
5151
"batch": 100,
52-
"profile_file": "aiter/configs/profile_fmoe.csv", # for all results
52+
"profile_file": "", # for all results
5353
}
5454

5555
def _setup_specific_arguments(self):
@@ -1958,6 +1958,13 @@ def post_process(self, results, args, topk=-1, fast_mode=False):
19581958
profileDF.drop(["tflops1", "tflops2", "bw1", "bw2"], axis=1, inplace=True)
19591959
profileDF["err1"] = profileDF["err1"].apply(lambda x: f"{x:.1%}")
19601960
profileDF["err2"] = profileDF["err2"].apply(lambda x: f"{x:.1%}")
1961+
if args.profile_file != "":
1962+
if os.path.exists(args.profile_file):
1963+
old_df = pd.read_csv(args.profile_file)
1964+
else:
1965+
old_df = pd.DataFrame(columns=self.columns)
1966+
tmpprofileDF = pd.concat([old_df, profileDF], ignore_index=True)
1967+
tmpprofileDF.to_csv(args.profile_file, index=False)
19611968
best_one = profileDF.loc[profileDF["us"].idxmin()].copy()
19621969
print(
19631970
f"Tuning result for {key} is {best_one['block_m'] ,best_one['kernelName1'], best_one['kernelName2'], best_one['err1'], best_one['err2'], best_one['run_1stage']} {best_one['us']} us, {best_one['tflops']} TFLOPS, {best_one['bw']} GB/s"
@@ -1971,12 +1978,12 @@ def post_process(self, results, args, topk=-1, fast_mode=False):
19711978
if len(prorfiles) > 0:
19721979
profile_result = pd.concat(prorfiles)
19731980
profile_result["err"] = profile_result["err"].apply(lambda x: f"{x:.1%}")
1981+
profile_file = f"aiter/configs/profile_fmoe.csv"
19741982
old_profile = self.get_tuned_gemm_list(
1975-
args.profile_file, profile_result.columns.tolist()
1983+
profile_file, profile_result.columns.tolist()
19761984
)
1977-
19781985
profile_result = pd.concat([old_profile, profile_result])
1979-
profile_result.to_csv(args.profile_file, index=False)
1986+
profile_result.to_csv(profile_file, index=False)
19801987
if len(bests) > 0:
19811988
return pd.concat(bests, axis=1).T
19821989
else:

0 commit comments

Comments
 (0)