@@ -379,6 +379,99 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         # return F.linear(input, self.weight.to(dtype=input.dtype)) * se...
 
 
+def linear_forward_8da8w(
+    x,
+    weight_int8,
+    scales,
+    zeros,
+    out_features,
+    precision,
+):
+    from torchao.quantization.utils import per_token_dynamic_quant
+
+    x = per_token_dynamic_quant(x)
+    n_bit = 8
+    quant_min = -(2 ** (n_bit - 1))
+    quant_max = 2 ** (n_bit - 1) - 1
+    w_dq = torch.ops.quantized_decomposed.dequantize_per_channel(
+        weight_int8,
+        scales,
+        zeros,
+        0,
+        quant_min,
+        quant_max,
+        torch.int8,
+        out_dtype=precision,
+    )
+    c = torch.nn.functional.linear(x, w_dq)
+
+    return c
+
+
+class Int8DynActInt8WeightLinear(torch.nn.Module):
+    __constants__ = ["in_features", "out_features"]
+
+    in_features: int
+    out_features: int
+    weight: torch.Tensor
+
418+ """
419+ This module implements a dynamic quantized linear layer with int8 weight.
420+ Weights are per channel quantized. Parameters of importance
421+ precision: precision of input and output. e.g. torch.float32 means input
422+ activation is float32 and output is float32.
423+ """
+
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias=True,
+        device=None,
+        dtype=None,
+        precision: torch.dtype = torch.float32,
+    ) -> None:
+        super().__init__()
+        self.in_features = in_features
+        self.out_features = out_features
+        assert not bias, "require bias=False"
+        self.precision = precision
+
+        if dtype is not None:
+            raise ValueError("Please specify 'precision' instead of 'dtype'")
+
+        # currently storing unpacked int8 weights
+        self.register_buffer(
+            "weight",
+            torch.empty((out_features, in_features), dtype=torch.int8),
+        )
+        self.register_buffer(
+            "scales",
+            torch.empty(
+                (out_features),
+                dtype=torch.float32,
+            ),
+        )
+        self.register_buffer(
+            "zeros",
+            torch.empty(
+                (out_features),
+                dtype=torch.float32,
+            ),
+        )
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        input = input.to(self.precision)
+        return linear_forward_8da8w(
+            input,
+            self.weight,
+            self.scales,
+            self.zeros,
+            self.out_features,
+            self.precision,
+        )
+
+
 #########################################################################
 ##### embedding table quantization ######
 
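For context: `forward` dynamically quantizes the activations per token, dequantizes the per-channel int8 weight back to `precision` (roughly `(weight_int8 - zeros) * scales` per output channel), and then calls `F.linear`. Below is a minimal usage sketch, not part of the diff, assuming torchao is installed (it provides `per_token_dynamic_quant` and the `quantized_decomposed` ops) and that a quantization flow fills the module's buffers; `fp_weight` and `lin` are illustrative names.

```python
import torch

# Hypothetical usage sketch: populate the buffers of Int8DynActInt8WeightLinear
# from a float weight via symmetric per-channel int8 quantization, then run a
# forward pass. A real flow would use a proper quantizer instead of this by hand.
in_features, out_features = 64, 32
lin = Int8DynActInt8WeightLinear(in_features, out_features, bias=False)

# Symmetric per-channel (per output row) int8 quantization of the weight.
fp_weight = torch.randn(out_features, in_features)
scales = fp_weight.abs().amax(dim=1) / 127.0
weight_int8 = (
    torch.round(fp_weight / scales.unsqueeze(1)).clamp(-128, 127).to(torch.int8)
)

lin.weight.copy_(weight_int8)
lin.scales.copy_(scales)
lin.zeros.copy_(torch.zeros(out_features))  # symmetric: zero points are all 0

x = torch.randn(2, in_features)
y = lin(x)  # activations are dynamically quantized per token inside forward
print(y.shape)  # torch.Size([2, 32])
```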