@@ -44,30 +44,42 @@ def _(
4444 bias : Optional [torch .Tensor ] = None ,
4545 dtype = torch .float16 ,
4646) -> torch .Tensor :
47- out_i32 = torch .ops .bitsandbytes .int8_linear_matmul (A , B )
48- out = torch .ops .bitsandbytes .int8_mm_dequant (out_i32 , row_stats , col_stats , dtype = dtype , bias = bias )
47+ out_i32 = torch .ops .bitsandbytes .int8_linear_matmul . default (A , B )
48+ out = torch .ops .bitsandbytes .int8_mm_dequant . default (out_i32 , row_stats , col_stats , dtype = dtype , bias = bias )
4949 return out
5050
5151
52- # Define op
53- # TODO: mutable output arg as alias of return can be challenging;
54- # consider a separate op without aliased return:
55- # int8_linear_matmul_out(
56- # Tensor A, Tensor B, Tensor out, ScalarType dtype=int32
57- # ) -> ()
58- # return () instead of `None` for compatibility, see here: https://github.com/pytorch/pytorch/issues/125044
5952torch .library .define (
6053 "bitsandbytes::int8_linear_matmul" ,
61- "(Tensor A, Tensor B, Tensor? out=None, ScalarType dtype=int32 ) -> Tensor" ,
54+ "(Tensor A, Tensor B) -> Tensor" ,
6255)
6356
6457
6558@register_fake ("bitsandbytes::int8_linear_matmul" )
66- def _ (A : torch .Tensor , B : torch .Tensor , out : Optional [torch .Tensor ] = None , dtype = torch .int32 ):
59+ def _ (A : torch .Tensor , B : torch .Tensor ):
60+ torch ._check (A .dtype == torch .int8 , lambda : "A must be int8" )
61+ torch ._check (B .dtype == torch .int8 , lambda : "B must be int8" )
6762 shapeC = (* A .shape [:- 1 ], B .shape [0 ])
68- if out is None :
69- return torch .empty (shapeC , device = A .device , dtype = dtype )
70- return out
63+ return torch .empty (shapeC , device = A .device , dtype = torch .int32 )
64+
65+
66+ # More info on `out` overloads:
67+ # https://github.com/pytorch/pytorch/issues/125044
68+ torch .library .define (
69+ "bitsandbytes::int8_linear_matmul.out" ,
70+ "(Tensor A, Tensor B, Tensor! out) -> ()" ,
71+ )
72+
73+
74+ @register_fake ("bitsandbytes::int8_linear_matmul.out" )
75+ def _ (A : torch .Tensor , B : torch .Tensor , out : torch .Tensor ):
76+ shapeC = (* A .shape [:- 1 ], B .shape [0 ])
77+
78+ torch ._check (A .dtype == torch .int8 , lambda : "A must be int8" )
79+ torch ._check (B .dtype == torch .int8 , lambda : "B must be int8" )
80+ torch ._check (out .shape == shapeC , lambda : f"Expected out.shape == { shapeC } , got { out .shape } " )
81+ torch ._check (out .device == A .device , lambda : f"Expected out.device == { A .device } , got { out .device } " )
82+ torch ._check (out .dtype == torch .int32 , lambda : f"Expected out.dtype == int32, got { out .dtype } " )
7183
7284
7385torch .library .define (
@@ -107,7 +119,7 @@ def _(A: torch.Tensor, stats: torch.Tensor):
107119
108120torch .library .define (
109121 "bitsandbytes::int8_mm_dequant" ,
110- "(Tensor A, Tensor row_stats, Tensor col_stats, ScalarType dtype=float16, Tensor? out=None, Tensor? bias=None) -> Tensor" ,
122+ "(Tensor A, Tensor row_stats, Tensor col_stats, ScalarType dtype=float16, Tensor? bias=None) -> Tensor" ,
111123)
112124
113125
@@ -117,7 +129,6 @@ def _(
117129 row_stats : torch .Tensor ,
118130 col_stats : torch .Tensor ,
119131 dtype = torch .float16 ,
120- out : Optional [torch .Tensor ] = None ,
121132 bias : Optional [torch .Tensor ] = None ,
122133) -> torch .Tensor :
123134 torch ._check (A .dtype == torch .int32 , lambda : "A must be int32" )
@@ -126,17 +137,13 @@ def _(
126137
127138torch .library .define (
128139 "bitsandbytes::int8_double_quant" ,
129- "(Tensor A, Tensor? col_stats, Tensor? row_stats, Tensor? out_col, Tensor? out_row, float threshold=0.0) -> (Tensor, Tensor, Tensor, Tensor, Tensor?)" ,
140+ "(Tensor A, float threshold=0.0) -> (Tensor, Tensor, Tensor, Tensor, Tensor?)" ,
130141)
131142
132143
133144@register_fake ("bitsandbytes::int8_double_quant" )
134145def _ (
135146 A : torch .Tensor ,
136- col_stats : Optional [torch .Tensor ] = None ,
137- row_stats : Optional [torch .Tensor ] = None ,
138- out_col : Optional [torch .Tensor ] = None ,
139- out_row : Optional [torch .Tensor ] = None ,
140147 threshold = 0.0 ,
141148) -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor , torch .Tensor , Optional [torch .Tensor ]]:
142149 out_row = torch .empty_like (A , dtype = torch .int8 )
@@ -156,12 +163,39 @@ def _(
156163
@register_fake("bitsandbytes::dequantize_4bit")
def _(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
) -> torch.Tensor:
    """Fake kernel: allocate an output of the requested `shape`/`dtype` on A's device."""
    torch._check_is_size(blocksize)
    return torch.empty(shape, device=A.device, dtype=dtype)
163175
164176
177+ torch .library .define (
178+ "bitsandbytes::dequantize_4bit.out" ,
179+ "(Tensor A, Tensor absmax, int blocksize, str quant_type, int[] shape, ScalarType dtype, Tensor! out) -> ()" ,
180+ )
181+
182+
183+ @register_fake ("bitsandbytes::dequantize_4bit.out" )
184+ def _ (
185+ A : torch .Tensor ,
186+ absmax : torch .Tensor ,
187+ blocksize : int ,
188+ quant_type : str ,
189+ shape : Sequence [int ],
190+ dtype : torch .dtype ,
191+ out : torch .Tensor ,
192+ ) -> None :
193+ torch ._check_is_size (blocksize )
194+ torch ._check (out .shape == shape , lambda : f"Expected out.shape == { shape } , got { out .shape } " )
195+ torch ._check (out .device == A .device , lambda : f"Expected out.device == { A .device } , got { out .device } " )
196+ torch ._check (out .dtype == dtype , lambda : f"Expected out.dtype == { dtype } , got { out .dtype } " )
197+
198+
165199torch .library .define (
166200 "bitsandbytes::quantize_4bit" ,
167201 "(Tensor A, int blocksize, str quant_type, ScalarType quant_storage) -> (Tensor, Tensor)" ,
@@ -194,6 +228,23 @@ def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int,
194228 return torch .empty_like (A , dtype = dtype )
195229
196230
231+ torch .library .define (
232+ "bitsandbytes::dequantize_blockwise.out" ,
233+ "(Tensor A, Tensor absmax, Tensor code, int blocksize, ScalarType dtype, Tensor! out) -> ()" ,
234+ )
235+
236+
237+ @register_fake ("bitsandbytes::dequantize_blockwise.out" )
238+ def _ (
239+ A : torch .Tensor , absmax : torch .Tensor , code : torch .Tensor , blocksize : int , dtype : torch .dtype , out : torch .Tensor
240+ ):
241+ torch ._check_is_size (blocksize )
242+ torch ._check (A .dtype == torch .uint8 , lambda : f"A must be uint8, got { A .dtype } " )
243+ torch ._check (out .shape == A .shape , lambda : f"Expected out.shape == { A .shape } , got { out .shape } " )
244+ torch ._check (out .device == A .device , lambda : f"Expected out.device == { A .device } , got { out .device } " )
245+ torch ._check (out .dtype == dtype , lambda : f"Expected out.dtype == { dtype } , got { out .dtype } " )
246+
247+
197248torch .library .define ("bitsandbytes::quantize_blockwise" , "(Tensor A, Tensor code, int blocksize) -> (Tensor, Tensor)" )
198249
199250
@@ -229,3 +280,37 @@ def _(
229280 )
230281 shape = (* A .shape [:- 1 ], shapeB [0 ])
231282 return torch .empty (shape , device = A .device , dtype = A .dtype )
283+
284+
285+ torch .library .define (
286+ "bitsandbytes::gemv_4bit.out" ,
287+ "(Tensor A, Tensor B, int[] shapeB, Tensor absmax, Tensor code, int blocksize, Tensor! out) -> ()" ,
288+ )
289+
290+
291+ @register_fake ("bitsandbytes::gemv_4bit.out" )
292+ def _ (
293+ A : torch .Tensor ,
294+ B : torch .Tensor ,
295+ shapeB : Sequence [int ],
296+ absmax : torch .Tensor ,
297+ code : torch .Tensor ,
298+ blocksize : int ,
299+ out : torch .Tensor ,
300+ ) -> None :
301+ torch ._check_is_size (blocksize )
302+ torch ._check (A .numel () == A .size (- 1 ), lambda : f"A must be a vector with leading dimensions of 1, got { A .shape } " )
303+ torch ._check (
304+ A .dtype in [torch .float16 , torch .bfloat16 , torch .float32 ],
305+ lambda : f"A must be float16, bfloat16, or float32, got { A .dtype } " ,
306+ )
307+ torch ._check (
308+ B .dtype in [torch .uint8 , torch .bfloat16 , torch .float16 , torch .float32 ],
309+ lambda : f"B must be backed by storage of type uint8, bfloat16, float16, or float32, got { B .dtype } " ,
310+ )
311+ torch ._check (
312+ out .shape == (* A .shape [:- 1 ], shapeB [0 ]),
313+ lambda : f"Expected out.shape == { (* A .shape [:- 1 ], shapeB [0 ])} , got { out .shape } " ,
314+ )
315+ torch ._check (out .device == A .device , lambda : f"Expected out.device == { A .device } , got { out .device } " )
316+ torch ._check (out .dtype == A .dtype , lambda : f"Expected out.dtype == { A .dtype } , got { out .dtype } " )
0 commit comments