Skip to content

Commit 04e2089

Browse files
authored
Fix parameter name in error message
1 parent 18e827d commit 04e2089

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

bitsandbytes/nn/modules.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -218,10 +218,10 @@ def set_compute_type(self, x):
218218
if self.compute_dtype == torch.float32 and (x.numel() == x.shape[-1]):
219219
# single batch inference with input torch.float16 and compute_dtype float32 -> slow inference when it could be fast
220220
# warn the user about this
221-
warnings.warn(f'Input type into Linear4bit is torch.float16, but bnb_4bit_compute_type=torch.float32 (default). This will lead to slow inference.')
221+
warnings.warn(f'Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference.')
222222
warnings.filterwarnings('ignore', message='.*inference.')
223223
if self.compute_dtype == torch.float32 and (x.numel() != x.shape[-1]):
224-
warnings.warn(f'Input type into Linear4bit is torch.float16, but bnb_4bit_compute_type=torch.float32 (default). This will lead to slow inference or training speed.')
224+
warnings.warn(f'Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 (default). This will lead to slow inference or training speed.')
225225
warnings.filterwarnings('ignore', message='.*inference or training')
226226

227227

0 commit comments

Comments (0)