Skip to content

Commit d039f17

Browse files
committed
change operator_lib selection to a new param
1 parent 9777aef commit d039f17

File tree

2 files changed

+16
-19
lines changed

2 files changed

+16
-19
lines changed

graph_net/torch/backend/flaggems_backend.py

Lines changed: 0 additions & 17 deletions
This file was deleted.

graph_net/torch/test_compiler.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@
2121
from graph_net.torch.backend.inductor_backend import InductorBackend
2222
from graph_net.torch.backend.tensorrt_backend import TensorRTBackend
2323
from graph_net.torch.backend.blade_disc_backend import BladeDISCBackend
24-
from graph_net.torch.backend.flaggems_backend import FlagGemsBackend
2524
from graph_net.torch.backend.nope_backend import NopeBackend
2625
from graph_net.torch.backend.unstable_to_stable_backend import UnstableToStableBackend
2726
from graph_net.torch.backend.range_decomposer_validator_backend import (
@@ -41,7 +40,6 @@
4140
"nope": NopeBackend(),
4241
"unstable_to_stable": UnstableToStableBackend(),
4342
"range_decomposer_validator": RangeDecomposerValidatorBackend(),
44-
"flaggems": FlagGemsBackend(),
4543
}
4644

4745

@@ -221,6 +219,14 @@ def test_single_model(args):
221219
compiled_types = []
222220
compiled_time_stats = {}
223221

222+
if args.operator_lib == "flaggems":
223+
try:
224+
import flag_gems
225+
except ImportError:
226+
flag_gems = None
227+
228+
flag_gems.enable()
229+
224230
try:
225231
compiled_model = compiler(model)
226232
torch.manual_seed(runtime_seed)
@@ -370,6 +376,7 @@ def test_multi_models(args):
370376
f"-m graph_net.torch.{module_name}",
371377
f"--model-path {model_path}",
372378
f"--compiler {args.compiler}",
379+
f"--operator-lib {args.operator_lib}",
373380
f"--device {args.device}",
374381
f"--warmup {args.warmup}",
375382
f"--trials {args.trials}",
@@ -418,6 +425,13 @@ def main(args):
418425
default="inductor",
419426
help="Path to customized compiler python file",
420427
)
428+
parser.add_argument(
429+
"--operator-lib",
430+
type=str,
431+
required=False,
432+
default="default",
433+
help="Customized operator library (eg. default, flaggems)",
434+
)
421435
parser.add_argument(
422436
"--device",
423437
type=str,

0 commit comments

Comments
 (0)