Skip to content

Commit ebcf044

Browse files
authored
Fix race condition when JitSpec loads the library (#1467)
1 parent cc46992 commit ebcf044

File tree

1 file changed

+30
-11
lines changed

1 file changed

+30
-11
lines changed

flashinfer/jit/core.py

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os
44
import re
55
import warnings
6+
from contextlib import nullcontext
67
from pathlib import Path
78
from typing import List, Optional, Sequence, Union
89

@@ -88,14 +89,22 @@ def jit_library_path(self) -> Path:
8889
return jit_env.FLASHINFER_JIT_DIR / self.name / f"{self.name}.so"
8990

9091
def get_library_path(self) -> Path:
91-
if self.aot_path.exists():
92+
if self.is_aot:
9293
return self.aot_path
9394
return self.jit_library_path
9495

9596
@property
9697
def aot_path(self) -> Path:
9798
return jit_env.FLASHINFER_AOT_DIR / self.name / f"{self.name}.so"
9899

100+
@property
101+
def is_aot(self) -> bool:
102+
return self.aot_path.exists()
103+
104+
@property
105+
def lock_path(self) -> Path:
106+
return get_tmpdir() / f"{self.name}.lock"
107+
99108
def write_ninja(self) -> None:
100109
ninja_path = self.ninja_path
101110
ninja_path.parent.mkdir(parents=True, exist_ok=True)
@@ -110,18 +119,14 @@ def write_ninja(self) -> None:
110119
)
111120
write_if_different(ninja_path, content)
112121

113-
def build(self, verbose: bool) -> None:
114-
tmpdir = get_tmpdir()
115-
with FileLock(tmpdir / f"{self.name}.lock", thread_local=False):
122+
def build(self, verbose: bool, need_lock: bool = True) -> None:
123+
lock = (
124+
FileLock(self.lock_path, thread_local=False) if need_lock else nullcontext()
125+
)
126+
with lock:
116127
run_ninja(jit_env.FLASHINFER_JIT_DIR, self.ninja_path, verbose)
117128

118-
def build_and_load(self, class_name: str = None):
119-
if self.aot_path.exists():
120-
so_path = self.aot_path
121-
else:
122-
so_path = self.jit_library_path
123-
verbose = os.environ.get("FLASHINFER_JIT_VERBOSE", "0") == "1"
124-
self.build(verbose)
129+
def load(self, so_path: Path, class_name: str = None):
125130
load_class = class_name is not None
126131
loader = torch.classes if load_class else torch.ops
127132
loader.load_library(so_path)
@@ -130,6 +135,20 @@ def build_and_load(self, class_name: str = None):
130135
return cls
131136
return getattr(loader, self.name)
132137

138+
def build_and_load(self, class_name: str = None):
139+
if self.is_aot:
140+
return self.load(self.aot_path, class_name)
141+
142+
# Guard both build and load with the same lock to avoid race condition
143+
# where another process is building the library and removes the .so file.
144+
with FileLock(self.lock_path, thread_local=False):
145+
so_path = self.jit_library_path
146+
verbose = os.environ.get("FLASHINFER_JIT_VERBOSE", "0") == "1"
147+
self.build(verbose, need_lock=False)
148+
result = self.load(so_path, class_name)
149+
150+
return result
151+
133152

134153
def gen_jit_spec(
135154
name: str,

0 commit comments

Comments
 (0)