Skip to content

Commit da94685

Browse files
committed
Add op lib selection in test device
1 parent 7d63d42 commit da94685

File tree

2 files changed

+42
-0
lines changed

2 files changed

+42
-0
lines changed

graph_net/torch/test_reference_device.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,15 @@ def get_reference_output_path(reference_dir, model_path):
2727
return os.path.join(reference_dir, f"{model_name}.pth")
2828

2929

30+
def register_op_lib(op_lib):
31+
if op_lib == "flaggems":
32+
import flag_gems
33+
34+
flag_gems.enable()
35+
else:
36+
pass
37+
38+
3039
def test_single_model(args):
3140
ref_log = get_reference_log_path(args.reference_dir, args.model_path)
3241
ref_dump = get_reference_output_path(args.reference_dir, args.model_path)
@@ -51,6 +60,10 @@ def test_single_model(args):
5160
test_compiler.get_compile_framework_version(args),
5261
)
5362

63+
test_compiler_util.print_with_log_prompt(
64+
"[Config] op_lib:", args.op_lib, args.log_prompt
65+
)
66+
5467
success = False
5568
time_stats = {}
5669
try:
@@ -99,6 +112,7 @@ def test_multi_models(args):
99112
f"--model-path {model_path}",
100113
f"--compiler {args.compiler}",
101114
f"--device {args.device}",
115+
f"--op-lib {args.op_lib}",
102116
f"--warmup {args.warmup}",
103117
f"--trials {args.trials}",
104118
f"--log-prompt {args.log_prompt}",
@@ -136,6 +150,7 @@ def main(args):
136150
ref_dump_dir.mkdir(parents=True, exist_ok=True)
137151

138152
if path_utils.is_single_model_dir(args.model_path):
153+
register_op_lib(args.op_lib)
139154
test_single_model(args)
140155
else:
141156
test_multi_models(args)
@@ -163,6 +178,13 @@ def main(args):
163178
default="cuda",
164179
help="Device for testing the compiler (e.g., 'cpu' or 'cuda')",
165180
)
181+
parser.add_argument(
182+
"--op-lib",
183+
type=str,
184+
required=False,
185+
default="default",
186+
help="Customized operator library (eg. default, flaggems)",
187+
)
166188
parser.add_argument(
167189
"--warmup", type=int, required=False, default=5, help="Number of warmup steps"
168190
)

graph_net/torch/test_target_device.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ def test_multi_models(args):
143143
f"-m graph_net.torch.{module_name}",
144144
f"--model-path {model_path}",
145145
f"--device {args.device}",
146+
f"--op-lib {args.op_lib}",
146147
f"--log-prompt {args.log_prompt}",
147148
f"--reference-dir {args.reference_dir}",
148149
]
@@ -167,6 +168,18 @@ def main(args):
167168
assert args.device in ["cuda", "dcu", "xpu", "cpu"]
168169

169170
if path_utils.is_single_model_dir(args.model_path):
171+
if args.op_lib == "origin":
172+
ref_log = test_reference_device.get_reference_log_path(
173+
args.reference_dir, args.model_path
174+
)
175+
config = parse_config_from_reference_log(ref_log)
176+
vars(args)["op_lib"] = config.get("op_lib")
177+
test_compiler_util.print_with_log_prompt(
178+
"[Config] op_lib:", args.op_lib, args.log_prompt
179+
)
180+
else:
181+
test_reference_device.register_op_lib(args.op_lib)
182+
170183
args = update_args_and_set_seed(args, args.model_path)
171184
test_single_model(args)
172185
else:
@@ -194,6 +207,13 @@ def main(args):
194207
default="cuda",
195208
help="Device for testing the compiler (e.g., 'cpu' or 'cuda')",
196209
)
210+
parser.add_argument(
211+
"--op-lib",
212+
type=str,
213+
required=False,
214+
default="default",
215+
help="Customized operator library (eg. default, flaggems or origin)",
216+
)
197217
parser.add_argument(
198218
"--log-prompt",
199219
type=str,

0 commit comments

Comments
 (0)