Skip to content

Commit b852d5f

Browse files
committed
Fix bitsandbytes kernel registration conflict on ROCm systems (#362)
Add ensure_bitsandbytes_safe() shim to handle broken/partial bitsandbytes installations that cause PyTorch kernel registration conflicts when diffusers attempts to re-import the module. On ROCm systems without proper binaries, bitsandbytes registers kernels during import then fails. When diffusers later imports it, the duplicate registration causes: 'RuntimeError: already a kernel registered...int8_mm_dequant' The shim pre-tests bitsandbytes import and stubs it only if broken, allowing working installations to function normally for other nodes.
1 parent 7ba37c0 commit b852d5f

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

src/optimization/compatibility.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,38 @@ def __getattr__(self, name):
8181
sys.modules['xformers._C_flashattention'] = stub
8282

8383

84+
def ensure_bitsandbytes_safe():
85+
"""
86+
Pre-test bitsandbytes; stub if broken to prevent import conflicts.
87+
88+
On some systems (e.g., ROCm without proper binaries), bitsandbytes registers
89+
PyTorch kernels during import then fails. If another node already triggered
90+
this partial load, re-importing causes kernel registration conflicts.
91+
92+
This shim catches such failures and stubs the module so diffusers can load
93+
gracefully without bitsandbytes quantization support.
94+
"""
95+
if 'bitsandbytes' in sys.modules:
96+
return # Already loaded or stubbed
97+
98+
try:
99+
import bitsandbytes
100+
# Success - bitsandbytes works, other nodes can use it
101+
except (ImportError, OSError, RuntimeError):
102+
# Installation broken or not present - create stub
103+
stub = types.ModuleType('bitsandbytes')
104+
stub.__spec__ = importlib.machinery.ModuleSpec('bitsandbytes', None)
105+
stub.__file__ = None
106+
stub.__path__ = []
107+
stub.__version__ = "0.0.0"
108+
sys.modules['bitsandbytes'] = stub
109+
110+
84111
# Run all shims immediately on import, before torch/diffusers
85112
ensure_triton_compat()
86113
ensure_flash_attn_safe()
87114
ensure_xformers_flash_compat()
115+
ensure_bitsandbytes_safe()
88116

89117

90118
import torch

0 commit comments

Comments
 (0)