@@ -44,11 +44,16 @@ def _get_cubin_dir():
44
44
import flashinfer_cubin
45
45
46
46
flashinfer_cubin_version = flashinfer_cubin .__version__
47
- if flashinfer_version != flashinfer_cubin_version :
47
+ # Allow bypassing version check with environment variable
48
+ if (
49
+ not os .getenv ("FLASHINFER_DISABLE_VERSION_CHECK" )
50
+ and flashinfer_version != flashinfer_cubin_version
51
+ ):
48
52
raise RuntimeError (
49
53
f"flashinfer-cubin version ({ flashinfer_cubin_version } ) does not match "
50
54
f"flashinfer version ({ flashinfer_version } ). "
51
- "Please install the same version of both packages."
55
+ "Please install the same version of both packages. "
56
+ "Set FLASHINFER_DISABLE_VERSION_CHECK=1 to bypass this check."
52
57
)
53
58
54
59
return pathlib .Path (flashinfer_cubin .get_cubin_dir ())
@@ -78,11 +83,15 @@ def _get_aot_dir():
78
83
flashinfer_jit_cache_version = flashinfer_jit_cache .__version__
79
84
# NOTE(Zihao): we don't use exact version match here because the version of flashinfer-jit-cache
80
85
# contains the CUDA version suffix: e.g. 0.3.1+cu129.
81
- if not flashinfer_jit_cache_version .startswith (flashinfer_version ):
86
+ # Allow bypassing version check with environment variable
87
+ if not os .getenv (
88
+ "FLASHINFER_DISABLE_VERSION_CHECK"
89
+ ) and not flashinfer_jit_cache_version .startswith (flashinfer_version ):
82
90
raise RuntimeError (
83
91
f"flashinfer-jit-cache version ({ flashinfer_jit_cache_version } ) does not match "
84
92
f"flashinfer version ({ flashinfer_version } ). "
85
- "Please install the same version of both packages."
93
+ "Please install the same version of both packages. "
94
+ "Set FLASHINFER_DISABLE_VERSION_CHECK=1 to bypass this check."
86
95
)
87
96
88
97
return pathlib .Path (flashinfer_jit_cache .get_jit_cache_dir ())
0 commit comments