@@ -2,9 +2,15 @@
 import importlib.util
 import logging
 from typing import Dict, Callable, List
+import torch
 
 logger = logging.getLogger(__name__)
 
+try:
+    import flag_gems
+except ImportError:
+    flag_gems = None
+
 
 class Backend:
     def __init__(self, name):
@@ -67,8 +73,6 @@ def _load_kernel_from_file(self, file_path: str, op_name: str) -> Callable:
 
     def _find_pytorch_op(self, op_name: str):
         """Map operation name to PyTorch operation."""
-        import torch
-
         # Try common patterns
         try:
             return getattr(torch.ops.aten, op_name).default
@@ -106,14 +110,10 @@ def __contains__(self, key):
 
 def _flag_gems_softmax(*args, **kwargs):
     # half_to_float is not supported in flag_gems
-    import flag_gems
-
     return flag_gems.ops.softmax(*args[:-1], **kwargs)
 
 
 def _flag_gems_layernorm(*args, **kwargs):
-    import flag_gems
-
     x, m, v = flag_gems.ops.layer_norm(*args[:-1], **kwargs)
     mv_shape = [*x.shape[:-1], 1]
     return x, m.view(*mv_shape), v.view(*mv_shape)
@@ -122,9 +122,6 @@ def _flag_gems_layernorm(*args, **kwargs):
 class FlagGemsBackend(Backend):
     def __init__(self) -> None:
         super().__init__("flaggems")
-        import flag_gems
-        import torch
-
         self.ops = {
             torch.ops.aten.abs.default: flag_gems.ops.abs,
             torch.ops.aten.abs_.default: flag_gems.ops.abs_,
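
With `flag_gems` imported once at module scope and bound to `None` when the package is missing, callers can check that sentinel before constructing `FlagGemsBackend`. A minimal caller-side sketch under that assumption (the guard itself is illustrative, not part of this diff):

```python
# Hypothetical usage, assuming this module exposes flag_gems (possibly None)
# and FlagGemsBackend as shown in the hunks above.
if flag_gems is None:
    raise RuntimeError("flag_gems is not installed; FlagGemsBackend is unavailable")

backend = FlagGemsBackend()
# Backend defines __contains__ (see the hunk context above), so membership
# checks against the registered aten ops work directly.
print(torch.ops.aten.abs.default in backend)
```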