3
3
import os
4
4
import re
5
5
import warnings
6
+ from contextlib import nullcontext
6
7
from pathlib import Path
7
8
from typing import List , Optional , Sequence , Union
8
9
@@ -88,14 +89,22 @@ def jit_library_path(self) -> Path:
88
89
return jit_env .FLASHINFER_JIT_DIR / self .name / f"{ self .name } .so"
89
90
90
91
def get_library_path (self ) -> Path :
91
- if self .aot_path . exists () :
92
+ if self .is_aot :
92
93
return self .aot_path
93
94
return self .jit_library_path
94
95
95
96
@property
96
97
def aot_path (self ) -> Path :
97
98
return jit_env .FLASHINFER_AOT_DIR / self .name / f"{ self .name } .so"
98
99
100
+ @property
101
+ def is_aot (self ) -> bool :
102
+ return self .aot_path .exists ()
103
+
104
+ @property
105
+ def lock_path (self ) -> Path :
106
+ return get_tmpdir () / f"{ self .name } .lock"
107
+
99
108
def write_ninja (self ) -> None :
100
109
ninja_path = self .ninja_path
101
110
ninja_path .parent .mkdir (parents = True , exist_ok = True )
@@ -110,18 +119,14 @@ def write_ninja(self) -> None:
110
119
)
111
120
write_if_different (ninja_path , content )
112
121
113
- def build (self , verbose : bool ) -> None :
114
- tmpdir = get_tmpdir ()
115
- with FileLock (tmpdir / f"{ self .name } .lock" , thread_local = False ):
122
+ def build (self , verbose : bool , need_lock : bool = True ) -> None :
123
+ lock = (
124
+ FileLock (self .lock_path , thread_local = False ) if need_lock else nullcontext ()
125
+ )
126
+ with lock :
116
127
run_ninja (jit_env .FLASHINFER_JIT_DIR , self .ninja_path , verbose )
117
128
118
- def build_and_load (self , class_name : str = None ):
119
- if self .aot_path .exists ():
120
- so_path = self .aot_path
121
- else :
122
- so_path = self .jit_library_path
123
- verbose = os .environ .get ("FLASHINFER_JIT_VERBOSE" , "0" ) == "1"
124
- self .build (verbose )
129
+ def load (self , so_path : Path , class_name : str = None ):
125
130
load_class = class_name is not None
126
131
loader = torch .classes if load_class else torch .ops
127
132
loader .load_library (so_path )
@@ -130,6 +135,20 @@ def build_and_load(self, class_name: str = None):
130
135
return cls
131
136
return getattr (loader , self .name )
132
137
138
+ def build_and_load (self , class_name : str = None ):
139
+ if self .is_aot :
140
+ return self .load (self .aot_path , class_name )
141
+
142
+ # Guard both build and load with the same lock to avoid race condition
143
+ # where another process is building the library and removes the .so file.
144
+ with FileLock (self .lock_path , thread_local = False ):
145
+ so_path = self .jit_library_path
146
+ verbose = os .environ .get ("FLASHINFER_JIT_VERBOSE" , "0" ) == "1"
147
+ self .build (verbose , need_lock = False )
148
+ result = self .load (so_path , class_name )
149
+
150
+ return result
151
+
133
152
134
153
def gen_jit_spec (
135
154
name : str ,
0 commit comments