Skip to content

Commit 4960932

Browse files
Fix torch.compile issue for LLM.int8() with threshold=0 (#1581)
1 parent 90bbe14 commit 4960932

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

bitsandbytes/autograd/_functions.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,13 @@ def get_inverse_transform_indices(
8484
return permuted_tile_indices
8585

8686

87+
# torch.compiler.is_compiling() is available only in torch >= 2.3
88+
if hasattr(torch.compiler, "is_compiling"):
89+
_is_compiling = torch.compiler.is_compiling
90+
else:
91+
_is_compiling = torch._dynamo.is_compiling
92+
93+
8794
@deprecated(
8895
"This function is deprecated and will be removed in a future release.",
8996
category=FutureWarning,
@@ -174,7 +181,7 @@ def forward(
174181
input_shape = A.shape
175182

176183
# Cast A to fp16
177-
if A.dtype != torch.float16:
184+
if A.dtype != torch.float16 and not _is_compiling():
178185
warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
179186

180187
if len(A.shape) == 3:

0 commit comments

Comments
 (0)