Skip to content

Commit bd4f808

Browse files
authored
Fix flag gems tests and imports (#35)
1 parent 9277c60 commit bd4f808

File tree

2 files changed

+7
-10
lines changed

2 files changed

+7
-10
lines changed

BackendBench/backends.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,15 @@
22
import importlib.util
33
import logging
44
from typing import Dict, Callable, List
5+
import torch
56

67
logger = logging.getLogger(__name__)
78

9+
try:
10+
import flag_gems
11+
except ImportError:
12+
flag_gems = None
13+
814

915
class Backend:
1016
def __init__(self, name):
@@ -67,8 +73,6 @@ def _load_kernel_from_file(self, file_path: str, op_name: str) -> Callable:
6773

6874
def _find_pytorch_op(self, op_name: str):
6975
"""Map operation name to PyTorch operation."""
70-
import torch
71-
7276
# Try common patterns
7377
try:
7478
return getattr(torch.ops.aten, op_name).default
@@ -106,14 +110,10 @@ def __contains__(self, key):
106110

107111
def _flag_gems_softmax(*args, **kwargs):
108112
# half_to_float is not supported in flag_gems
109-
import flag_gems
110-
111113
return flag_gems.ops.softmax(*args[:-1], **kwargs)
112114

113115

114116
def _flag_gems_layernorm(*args, **kwargs):
115-
import flag_gems
116-
117117
x, m, v = flag_gems.ops.layer_norm(*args[:-1], **kwargs)
118118
mv_shape = [*x.shape[:-1], 1]
119119
return x, m.view(*mv_shape), v.view(*mv_shape)
@@ -122,9 +122,6 @@ def _flag_gems_layernorm(*args, **kwargs):
122122
class FlagGemsBackend(Backend):
123123
def __init__(self) -> None:
124124
super().__init__("flaggems")
125-
import flag_gems
126-
import torch
127-
128125
self.ops = {
129126
torch.ops.aten.abs.default: flag_gems.ops.abs,
130127
torch.ops.aten.abs_.default: flag_gems.ops.abs_,

test/test_backends.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def test_flag_gems_backend_contains_op(self, mock_flag_gems):
7373
@patch("BackendBench.backends.flag_gems")
7474
def test_flag_gems_backend_getitem(self, mock_flag_gems):
7575
mock_abs_impl = Mock()
76-
mock_flag_gems.abs = mock_abs_impl
76+
mock_flag_gems.ops.abs = mock_abs_impl
7777

7878
backend = FlagGemsBackend()
7979

0 commit comments

Comments
 (0)