@@ -17,11 +17,11 @@ def _(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtyp
     shapeA = A.shape
     shapeB = B.shape

-    torch._check(A.dtype == torch.int8, "B must be int8")
-    torch._check(B.dtype == torch.int8, "A must be int8")
-    torch._check(A.ndim == 2, "Only two dimensional matrices are supported for argument B")
-    torch._check(B.ndim in [2, 3], "Only two or three dimensional matrices are supported for argument A")
-    torch._check(prod(shapeB) > 0, f"Input tensor dimensions need to be > 0: {shapeB}")
+    torch._check(A.dtype == torch.int8, lambda: "B must be int8")
+    torch._check(B.dtype == torch.int8, lambda: "A must be int8")
+    torch._check(A.ndim == 2, lambda: "Only two dimensional matrices are supported for argument B")
+    torch._check(B.ndim in [2, 3], lambda: "Only two or three dimensional matrices are supported for argument A")
+    torch._check(prod(shapeB) > 0, lambda: f"Input tensor dimensions need to be > 0: {shapeB}")
     torch._check(out is None or out.dtype == dtype)

     shapeC = (*shapeB[:-1], shapeA[0])
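Note: `torch._check` accepts either a string or a zero-argument callable for its message, and only invokes the callable when the condition fails. Switching to `lambda:` thunks therefore defers the f-string formatting to the error path. A minimal sketch of the pattern (the function name is illustrative):

    import torch

    def require_int8(t: torch.Tensor) -> torch.Tensor:
        # The f-string is only evaluated if the check fails, so the
        # happy path pays no string-formatting cost.
        torch._check(t.dtype == torch.int8, lambda: f"expected int8, got {t.dtype}")
        return t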
@@ -34,7 +34,7 @@ def _(A: torch.Tensor, B: torch.Tensor, out: Optional[torch.Tensor] = None, dtyp

     torch._check(
         lda == ldb,
-        f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}",
+        lambda: f"int8_linear_matmul only supports B^T @ A. Inner dimensions do not match: B @ A = {shapeB} @ {shapeA}",
     )

     # cuBLASLt does not support int8 matmul with inner dimensions that are not divisible by 4.
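The comment above records a cuBLASLt constraint: int8 matmul inner dimensions must be divisible by 4. As one hedged illustration of how a caller could satisfy that constraint (the helper name and the zero-padding strategy are assumptions for this sketch, not what this kernel itself does), the inner dimension can be rounded up before invoking the op:

    import torch
    import torch.nn.functional as F

    def pad_inner_dim(A: torch.Tensor, multiple: int = 4) -> torch.Tensor:
        # Zero-pad the last (inner) dimension up to the next multiple;
        # the added zeros contribute nothing to the int8 dot products.
        pad = -A.shape[-1] % multiple
        return F.pad(A, (0, pad)) if pad else A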
@@ -92,10 +92,12 @@ def _(
     out: Optional[torch.Tensor] = None,
     bias: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
-    torch._check(A.dtype == torch.int32, "A must be int32")
+    torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}")
+    torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}")
+    torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}")

     if bias is not None:
-        torch._check(bias.dtype == torch.float16)
+        torch._check(bias.dtype == torch.float16, lambda: f"Only fp16 bias is supported, got {bias.dtype}")

     if out is None:
         out = torch.empty_like(A, dtype=torch.float16)
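For orientation, a pure-PyTorch sketch of the dequantization these checks guard, under the assumption (the usual LLM.int8() convention, not stated in this hunk) that `A` is the int32 accumulator of an int8 x int8 matmul and `row_stats`/`col_stats` hold the per-row/per-column absmax values used at quantization time; the two 1/127 factors undo the two quantization scales. This is a reference sketch, not the CUDA kernel:

    from typing import Optional
    import torch

    def int8_mm_dequant_ref(A: torch.Tensor, row_stats: torch.Tensor,
                            col_stats: torch.Tensor,
                            bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Undo both absmax/127 scales, then add the fp16 bias if given.
        out = A * (row_stats.unsqueeze(-1) * col_stats.unsqueeze(0)) / (127.0 * 127.0)
        if bias is not None:
            out = out + bias
        return out.to(torch.float16)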
@@ -118,7 +120,8 @@ def _(

 @register_kernel("bitsandbytes::int8_vectorwise_quant", "cuda")
 def _(A: torch.Tensor, threshold=0.0):
-    torch._check(A.dtype == torch.float16, "A must be float16")
+    torch._check(A.dtype == torch.float16, lambda: f"A must be float16, got {A.dtype}")
+    torch._check(threshold >= 0.0, lambda: "threshold must be non-negative")

     rows = prod(A.shape[:-1])
     cols = A.shape[-1]
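A minimal sketch of the row-wise quantization this op performs, ignoring the `threshold` outlier path (roughly: with `threshold > 0`, values above it are excluded from the row statistics and tracked separately; that logic is omitted here, and the clamp guarding all-zero rows is an addition for this sketch):

    import torch

    def int8_vectorwise_quant_ref(A: torch.Tensor):
        # Per-row absmax over the last dim; scale each row into [-127, 127].
        row_stats = A.abs().amax(dim=-1, keepdim=True).float().clamp_min(1e-12)
        quantized = torch.round(A * (127.0 / row_stats)).to(torch.int8)
        return quantized, row_stats.squeeze(-1)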
@@ -205,12 +208,14 @@ def _get_col_absmax(

 @register_kernel("bitsandbytes::quantize_blockwise", "cuda")
 def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> Tuple[torch.Tensor, torch.Tensor]:
+    torch._check_is_size(blocksize)
     torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
+    torch._check(code.dtype == torch.float32, lambda: f"code must be float32, got {code.dtype}")

     n = A.numel()
     blocks = -(n // -blocksize)
-    absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
-    out = torch.zeros_like(A, dtype=torch.uint8)
+    absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
+    out = torch.empty_like(A, dtype=torch.uint8)

     with _cuda_device_of(A):
         args = (
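Two details in this hunk: `-(n // -blocksize)` is ceiling division expressed with floor division (no `math.ceil` or float round-trip needed), and `zeros`/`zeros_like` become `empty`/`empty_like` because the kernel overwrites every element of `absmax` and `out`, making the zero-fill redundant work. A worked check of the ceiling trick:

    import math

    n, blocksize = 1000, 256
    blocks = -(n // -blocksize)   # 1000 // -256 == -4, negated -> 4
    assert blocks == math.ceil(n / blocksize) == 4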
@@ -237,6 +242,10 @@ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> Tuple[torch.Tensor
 @register_kernel("bitsandbytes::dequantize_blockwise", "cuda")
 def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
     torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
+    torch._check(
+        dtype in [torch.float16, torch.bfloat16, torch.float32],
+        lambda: f"Blockwise dequantization only supports 16bit/32bit floating types, got {dtype}",
+    )

     out = torch.empty_like(A, dtype=dtype)

@@ -257,8 +266,6 @@ def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int,
             lib.cdequantize_blockwise_bf16(*args)
         elif dtype == torch.float32:
             lib.cdequantize_blockwise_fp32(*args)
-        else:
-            raise ValueError(f"Blockwise dequantization only supports 16/32-bit floats, but got {dtype}")

     return out

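With the dtype check hoisted into `torch._check` at the top of the op (previous hunk), the trailing `else: raise ValueError(...)` dispatch branch is unreachable and is removed; unsupported dtypes now fail before `out` is allocated or any kernel is launched. One behavioral note worth flagging: `torch._check` raises `RuntimeError` rather than `ValueError`, so callers catching the old exception type will see a different one:

    import torch

    try:
        torch._check(False, lambda: "unsupported dtype")
    except RuntimeError as e:
        print(e)  # prints: unsupported dtype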
@@ -269,6 +276,10 @@ def _(
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
     torch._check(quant_type in ["fp4", "nf4"])
+    torch._check(
+        A.dtype in [torch.bfloat16, torch.float16, torch.float32],
+        lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}",
+    )

     n = A.numel()
     blocks = -(n // -blocksize)
@@ -300,8 +311,6 @@ def _(
                 lib.cquantize_blockwise_fp32_fp4(*args)
             else:
                 lib.cquantize_blockwise_fp32_nf4(*args)
-        else:
-            raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {A.dtype}")

     return out, absmax

@@ -312,6 +321,10 @@ def _(
 ) -> torch.Tensor:
     torch._check(blocksize in [4096, 2048, 1024, 512, 256, 128, 64])
     torch._check(quant_type in ["fp4", "nf4"])
+    torch._check(
+        dtype in [torch.bfloat16, torch.float16, torch.float32],
+        lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}",
+    )

     out = torch.empty(shape, dtype=dtype, device=A.device)
     n = out.numel()
@@ -344,7 +357,5 @@ def _(
                 lib.cdequantize_blockwise_fp32_fp4(*args)
             else:
                 lib.cdequantize_blockwise_fp32_nf4(*args)
-        else:
-            raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}")

     return out