|
21 | 21 | from graph_net.torch.backend.inductor_backend import InductorBackend |
22 | 22 | from graph_net.torch.backend.tensorrt_backend import TensorRTBackend |
23 | 23 | from graph_net.torch.backend.blade_disc_backend import BladeDISCBackend |
24 | | -from graph_net.torch.backend.flaggems_backend import FlagGemsBackend |
25 | 24 | from graph_net.torch.backend.nope_backend import NopeBackend |
26 | 25 | from graph_net.torch.backend.unstable_to_stable_backend import UnstableToStableBackend |
27 | 26 | from graph_net.torch.backend.range_decomposer_validator_backend import ( |
|
41 | 40 | "nope": NopeBackend(), |
42 | 41 | "unstable_to_stable": UnstableToStableBackend(), |
43 | 42 | "range_decomposer_validator": RangeDecomposerValidatorBackend(), |
44 | | - "flaggems": FlagGemsBackend(), |
45 | 43 | } |
46 | 44 |
|
47 | 45 |
|
@@ -221,6 +219,14 @@ def test_single_model(args): |
221 | 219 | compiled_types = [] |
222 | 220 | compiled_time_stats = {} |
223 | 221 |
|
| 222 | + if args.operator_lib == "flaggems": |
| 223 | + try: |
| 224 | + import flag_gems |
| 225 | + except ImportError: |
| 226 | + flag_gems = None |
| 227 | + |
| 228 | + flag_gems.enable() |
| 229 | + |
224 | 230 | try: |
225 | 231 | compiled_model = compiler(model) |
226 | 232 | torch.manual_seed(runtime_seed) |
@@ -370,6 +376,7 @@ def test_multi_models(args): |
370 | 376 | f"-m graph_net.torch.{module_name}", |
371 | 377 | f"--model-path {model_path}", |
372 | 378 | f"--compiler {args.compiler}", |
| 379 | + f"--operator-lib {args.operator_lib}", |
373 | 380 | f"--device {args.device}", |
374 | 381 | f"--warmup {args.warmup}", |
375 | 382 | f"--trials {args.trials}", |
@@ -418,6 +425,13 @@ def main(args): |
418 | 425 | default="inductor", |
419 | 426 | help="Path to customized compiler python file", |
420 | 427 | ) |
| 428 | + parser.add_argument( |
| 429 | + "--operator-lib", |
| 430 | + type=str, |
| 431 | + required=False, |
| 432 | + default="default", |
| 433 | + help="Customized operator library (eg. default, flaggems)", |
| 434 | + ) |
421 | 435 | parser.add_argument( |
422 | 436 | "--device", |
423 | 437 | type=str, |
|
0 commit comments