diff --git a/flashinfer/jit/cubin_loader.py b/flashinfer/jit/cubin_loader.py index 78615e86f9..286aefabf6 100644 --- a/flashinfer/jit/cubin_loader.py +++ b/flashinfer/jit/cubin_loader.py @@ -77,10 +77,13 @@ def download_file( with lock: logger.info(f"Acquired lock for {local_path}") + temp_path = f"{local_path}.tmp" + # Handle local file copy if os.path.exists(source): try: - shutil.copy(source, local_path) + shutil.copy(source, temp_path) + os.replace(temp_path, local_path) # Atomic rename logger.info(f"File copied successfully: {local_path}") return True except Exception as e: @@ -93,9 +96,12 @@ def download_file( response = session.get(source, timeout=timeout) response.raise_for_status() - with open(local_path, "wb") as file: + with open(temp_path, "wb") as file: file.write(response.content) + # Atomic rename to prevent readers from seeing partial writes + os.replace(temp_path, local_path) + logger.info( f"File downloaded successfully: {source} -> {local_path}" ) diff --git a/tests/utils/test_load_cubin_compile_race_condition.py b/tests/utils/test_load_cubin_compile_race_condition.py new file mode 100644 index 0000000000..29b2165a54 --- /dev/null +++ b/tests/utils/test_load_cubin_compile_race_condition.py @@ -0,0 +1,116 @@ +""" +Copyright (c) 2025 by FlashInfer team. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+"""
+
+import os
+import tempfile
+from pathlib import Path
+from multiprocessing import Pool
+
+# This file is a standalone stress script rather than a pytest case: the entry
+# point below takes plain positional arguments, which pytest would try to
+# resolve as fixtures and then report as an error. Opt the module out of pytest
+# collection explicitly; run it with ``python <this file>`` instead.
+__test__ = False
+
+
+def worker_process(temp_dir):
+    """Fetch the shared header through get_cubin and return the bytes on disk.
+
+    Runs inside a pool worker. FLASHINFER_CUBIN_DIR is pointed at ``temp_dir``
+    before flashinfer is imported, which is why the imports are function-scoped.
+
+    Args:
+        temp_dir: Directory exported as FLASHINFER_CUBIN_DIR for this process.
+
+    Returns:
+        Raw bytes of ``flashinferMetaInfo.h`` as read back from under ``temp_dir``.
+    """
+    # Set environment variable for this process
+    os.environ["FLASHINFER_CUBIN_DIR"] = temp_dir
+
+    # Import here to ensure FLASHINFER_CUBIN_DIR is set before module loads
+    from flashinfer.artifacts import ArtifactPath, MetaInfoHash
+    from flashinfer.jit.cubin_loader import get_cubin
+
+    # Every worker targets the same file name; that is the contended resource.
+    include_path = f"{ArtifactPath.TRTLLM_GEN_BMM}/include"
+    header_name = "flashinferMetaInfo"
+
+    # The return value is not needed here (the file is read back from disk
+    # below), so it is not bound to a name.
+    get_cubin(f"{include_path}/{header_name}.h", MetaInfoHash.TRTLLM_GEN_BMM)
+
+    # Read the file from FLASHINFER_CUBIN_DIR
+    # NOTE(Zihao): instead of using metainfo, we directly read from the file path,
+    # that aligns with how we compile the kernel.
+    file_path = Path(temp_dir) / include_path / f"{header_name}.h"
+    return file_path.read_bytes()
+
+
+def test_load_cubin_race_condition(num_iterations, num_processes):
+    """Stress concurrent get_cubin calls that all target the same file.
+
+    Each iteration creates a fresh temporary FLASHINFER_CUBIN_DIR, runs
+    ``num_processes`` pool workers against it, and asserts that every worker
+    read byte-identical content. Repeating raises the chance of hitting a race.
+
+    Args:
+        num_iterations: Number of times to repeat the test.
+        num_processes: Number of concurrent processes per iteration.
+
+    Raises:
+        AssertionError: If the number of results is wrong or any worker's bytes
+            differ from those read by the first worker.
+
+    Note:
+        Intended to be run directly (see the ``__main__`` guard); the module
+        sets ``__test__ = False`` so pytest does not collect it.
+    """
+    for iteration in range(num_iterations):
+        # TemporaryDirectory removes the tree on exit, including when one of the
+        # assertions below fails.
+        with tempfile.TemporaryDirectory(prefix="flashinfer_test_cubin_") as temp_dir:
+            # Launch multiple processes concurrently
+            with Pool(processes=num_processes) as pool:
+                results = pool.map(worker_process, [temp_dir] * num_processes)
+
+            # pool.map returns one entry per submitted task.
+            assert len(results) == num_processes, (
+                f"Expected {num_processes} results, got {len(results)}"
+            )
+
+            # All results should be identical; compare each against the first.
+            first_content = results[0]
+            for i, content in enumerate(results):
+                assert content == first_content, (
+                    f"Iteration {iteration + 1}/{num_iterations}, Process {i} read different content. "
+                    f"Expected length {len(first_content)}, got {len(content)}"
+                )
+
+            # Report the first iteration and every tenth one to keep output short.
+            if (iteration + 1) % 10 == 0 or iteration == 0:
+                print(
+                    f"Iteration {iteration + 1}/{num_iterations}: {num_processes} processes all read the same content"
+                )
+
+    print(
+        f"\nAll tests passed: {num_iterations} iterations × {num_processes} processes"
+    )
+
+
+if __name__ == "__main__":
+    # NOTE(Zihao): do not use pytest to run this test
+    test_load_cubin_race_condition(100, 10)