Skip to content

Commit c08b529

Browse files
authored
Fix: Prevent race condition in cubin loader when file is being consumed (#1852)
1 parent da18466 commit c08b529

File tree

2 files changed

+124
-2
lines changed

2 files changed

+124
-2
lines changed

flashinfer/jit/cubin_loader.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,13 @@ def download_file(
7777
with lock:
7878
logger.info(f"Acquired lock for {local_path}")
7979

80+
temp_path = f"{local_path}.tmp"
81+
8082
# Handle local file copy
8183
if os.path.exists(source):
8284
try:
83-
shutil.copy(source, local_path)
85+
shutil.copy(source, temp_path)
86+
os.replace(temp_path, local_path) # Atomic rename
8487
logger.info(f"File copied successfully: {local_path}")
8588
return True
8689
except Exception as e:
@@ -93,9 +96,12 @@ def download_file(
9396
response = session.get(source, timeout=timeout)
9497
response.raise_for_status()
9598

96-
with open(local_path, "wb") as file:
99+
with open(temp_path, "wb") as file:
97100
file.write(response.content)
98101

102+
# Atomic rename to prevent readers from seeing partial writes
103+
os.replace(temp_path, local_path)
104+
99105
logger.info(
100106
f"File downloaded successfully: {source} -> {local_path}"
101107
)
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
"""
2+
Copyright (c) 2025 by FlashInfer team.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
"""
16+
17+
import os
18+
import tempfile
19+
from pathlib import Path
20+
from multiprocessing import Pool
21+
22+
23+
def worker_process(temp_dir):
24+
"""
25+
Worker function that each process executes.
26+
27+
Each process will:
28+
1. Set FLASHINFER_CUBIN_DIR environment variable
29+
2. Import and call get_cubin with the same target file
30+
3. Read the file from FLASHINFER_CUBIN_DIR
31+
4. Return the file content
32+
"""
33+
# Set environment variable for this process
34+
os.environ["FLASHINFER_CUBIN_DIR"] = temp_dir
35+
36+
# Import here to ensure FLASHINFER_CUBIN_DIR is set before module loads
37+
from flashinfer.artifacts import ArtifactPath, MetaInfoHash
38+
from flashinfer.jit.cubin_loader import get_cubin
39+
40+
# Define the target file - same for all processes
41+
include_path = f"{ArtifactPath.TRTLLM_GEN_BMM}/include"
42+
header_name = "flashinferMetaInfo"
43+
44+
# Use get_cubin to get "flashinferMetaInfo.h"
45+
# Note: all processes target the same file name
46+
metainfo = get_cubin(f"{include_path}/{header_name}.h", MetaInfoHash.TRTLLM_GEN_BMM) # noqa: F841
47+
48+
# Read the file from FLASHINFER_CUBIN_DIR
49+
# NOTE(Zihao): instead of using metainfo, we directly read from the file path,
50+
# that aligns with how we compile the kernel.
51+
file_path = Path(temp_dir) / include_path / f"{header_name}.h"
52+
with open(file_path, "rb") as f:
53+
content = f.read()
54+
55+
return content
56+
57+
58+
def test_load_cubin_race_condition(num_iterations, num_processes):
59+
"""
60+
Test race condition when multiple processes concurrently call get_cubin
61+
for the same file.
62+
63+
Test steps:
64+
1. Set up a temporary FLASHINFER_CUBIN_DIR
65+
2. Launch multiple processes
66+
3. Each process calls get_cubin for the same target file
67+
4. Each process reads the downloaded file
68+
5. Verify all processes read the same content
69+
6. Repeat multiple times to increase chance of detecting race conditions
70+
71+
Args:
72+
num_iterations: Number of times to repeat the test
73+
num_processes: Number of concurrent processes per iteration
74+
"""
75+
import shutil
76+
77+
for iteration in range(num_iterations):
78+
# Create a temporary directory for FLASHINFER_CUBIN_DIR
79+
temp_dir = tempfile.mkdtemp(prefix="flashinfer_test_cubin_")
80+
81+
try:
82+
# Launch multiple processes concurrently
83+
with Pool(processes=num_processes) as pool:
84+
results = pool.map(worker_process, [temp_dir] * num_processes)
85+
86+
# Verify all processes read the same content
87+
assert len(results) == num_processes, (
88+
f"Expected {num_processes} results, got {len(results)}"
89+
)
90+
91+
# All results should be identical
92+
first_content = results[0]
93+
for i, content in enumerate(results):
94+
assert content == first_content, (
95+
f"Iteration {iteration + 1}/{num_iterations}, Process {i} read different content. "
96+
f"Expected length {len(first_content)}, got {len(content)}"
97+
)
98+
99+
if (iteration + 1) % 10 == 0 or iteration == 0:
100+
print(
101+
f"Iteration {iteration + 1}/{num_iterations}: {num_processes} processes all read the same content"
102+
)
103+
104+
finally:
105+
# Clean up temporary directory
106+
if os.path.exists(temp_dir):
107+
shutil.rmtree(temp_dir)
108+
109+
print(
110+
f"\nAll tests passed: {num_iterations} iterations × {num_processes} processes"
111+
)
112+
113+
114+
if __name__ == "__main__":
115+
# NOTE(Zihao): do not use pytest to run this test
116+
test_load_cubin_race_condition(100, 10)

0 commit comments

Comments
 (0)