@@ -60,10 +60,7 @@ def scaled_e4m3_impl(
    with torch.cuda.device(
        None if inputs.device.index == torch.cuda.current_device() else inputs.device.index
    ):
-        if amax is None:
-            # This adds overhead; however this is not a common use case.
-            amax = torch.tensor(448.0, device=inputs.device, dtype=inputs.dtype)
-        if amax.numel() == 1:
+        if amax is None or amax.numel() == 1:
            outputs = cuda_ext_fp8.fake_e4m3fy(inputs, amax)
        else:
            if amax.squeeze().ndim > 1:
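With this change, a missing amax and a scalar amax take the same fused-kernel path; the removed default simply materialized 448.0, the E4M3 max-normal value. Below is a minimal pure-PyTorch sketch of the equivalent fake quantization, assuming torch>=2.1 for `torch.float8_e4m3fn`; the helper name is illustrative and not part of this module:

```python
import torch

def fake_e4m3_ref(inputs: torch.Tensor, amax: torch.Tensor | None = None) -> torch.Tensor:
    # amax=None behaves like amax=448.0 (the E4M3 max-normal value): no rescaling.
    scale = 1.0 if amax is None else 448.0 / amax.float()
    x = (inputs.float() * scale).clamp(-448.0, 448.0)  # saturate to the E4M3 range
    x = x.to(torch.float8_e4m3fn).float()              # round onto the E4M3 grid
    return (x / scale).to(inputs.dtype)                # dequantize back
```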
@@ -556,136 +553,6 @@ def backward(ctx, grad_outputs):
        return _fake_quant_backward_function(ctx, grad_outputs, num_args=9)


-class TensorQuantFunction(Function):
-    """A universal tensor quantization function.
-
-    Take an input tensor, output a quantized tensor. The granularity of scale can be interpreted from the
-    shape of amax.
-    output_dtype indicates whether the quantized value will be stored as integer or float. The reason to store
-    it in float is that the PyTorch function taking the quantized value may not accept integer input, e.g. Conv2D.
-
-    It uses 2^num_bits - 1 values instead of 2^num_bits, e.g., for num_bits=8 it uses [-127, 127] instead of [-128, 127].
-    """
-
-    @staticmethod
-    @symbolic_helper.parse_args("v", "t", "t", "i", "b", "b", "s")
-    def symbolic(
-        g,
-        inputs,
-        amax,
-        bias=None,
-        num_bits=8,
-        unsigned=False,
-        narrow_range=True,
-        trt_high_precision_dtype=None,
-    ):
-        """ONNX symbolic function."""
-        from .export_onnx import export_int8
-
-        return export_int8(
-            g, inputs, amax, num_bits, unsigned, narrow_range, trt_high_precision_dtype
-        )
-
-    @staticmethod
-    def forward(
-        ctx,
-        inputs,
-        amax,
-        bias=None,
-        num_bits=8,
-        unsigned=False,
-        narrow_range=True,
-        trt_high_precision_dtype=None,
-    ):
-        """Forward method.
-
-        Following TensorFlow convention, the max value is passed in and used to derive the scale, instead of
-        inputting the scale directly, though inputting the scale directly may be more natural to use.
-
-        Args:
-            ctx: A Context object to store tensors for backward.
-            inputs: A Tensor of type float32.
-            amax: A Tensor of type float32. Inputs will be quantized within the range [-amax, amax];
-                amax will be broadcast to the inputs tensor.
-            num_bits: An integer used to calculate the scaling factor, scale = (2^(num_bits-1) - 1) / max.
-                Effectively, it indicates how many integer bits are used to represent the value. Default 8.
-            output_dtype: A type of Tensor. torch.int32 or torch.float32.
-            unsigned: A boolean. Use unsigned integer range, e.g. [0, 255] for num_bits=8. Default False.
-            narrow_range: A boolean. Use symmetric integer range for signed quantization,
-                e.g. [-127, 127] instead of [-128, 127] for num_bits=8. Default True.
-
-        Returns:
-            outputs: A Tensor of type output_dtype.
-            scale: A Tensor of type float32. outputs / scale will dequantize the outputs tensor.
-
-        Raises:
-            ValueError: If the scale overflows FP16.
-        """
-        if bias is not None:
-            inputs = inputs - bias
-
-        ctx.save_for_backward(inputs, amax)
-
-        outputs, scale = _tensor_quant(inputs, amax, num_bits, unsigned, narrow_range)
-        # Check if scale overflows FP16
-        if outputs.dtype == torch.half and scale.max() > 65504:
-            raise ValueError(f"scale is too large for FP16 with amax={amax}")
-
-        if bias is not None:
-            outputs = outputs + bias
-
-        return outputs, scale.to(inputs.dtype)
-
-    @staticmethod
-    def backward(ctx, grad_outputs, grad_scale):
-        """Implements straight-through estimation with clipping.
-
-        For -amax <= input <= amax the gradient passes straight through; otherwise the gradient is zero.
-
-        Args:
-            ctx: A Context object with saved tensors from forward.
-            grad_outputs: A tensor of gradient of outputs.
-            grad_scale: A tensor of gradient of scale.
-
-        Returns:
-            grad_inputs: A tensor of gradient.
-        """
-        inputs, amax = ctx.saved_tensors
-        zero = grad_outputs.new_zeros(1)  # create a zero tensor with the same type and device
-        grad_inputs = torch.where(inputs.abs() <= amax, grad_outputs, zero)
-        return grad_inputs, None, None, None, None, None, None
-
-
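For reference, a worked example of the symmetric, narrow-range scheme the removed class implemented (the values are illustrative):

```python
import torch

amax, num_bits = torch.tensor(2.0), 8
scale = (2.0 ** (num_bits - 1) - 1) / amax          # 127 / 2.0 = 63.5
x = torch.tensor([1.0, -3.0])
q = torch.clamp(torch.round(x * scale), -127, 127)  # tensor([ 64., -127.]); -3.0 clips
x_hat = q / scale                                   # tensor([ 1.0079, -2.0000])
```

The backward above then passes gradients only where |x| <= amax, so the clipped -3.0 entry receives zero gradient.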
-class LegacyFakeTensorQuantFunction(Function):
-    """Fake version of TensorQuantFunction.
-
-    See comments of TensorQuantFunction; arguments are the same.
-    """
-
-    @staticmethod
-    def forward(ctx, inputs, amax, bias, num_bits=8, unsigned=False, narrow_range=True):
-        """Forward method."""
-        if bias is not None:
-            inputs = inputs - bias
-
-        ctx.save_for_backward(inputs, amax)
-
-        outputs, scale = _tensor_quant(inputs, amax, num_bits, unsigned, narrow_range)
-
-        if bias is not None:
-            outputs = outputs + bias
-
-        return outputs / scale.to(inputs.dtype)
-
-    @staticmethod
-    def backward(ctx, grad_outputs):
-        """Implements straight-through estimation."""
-        inputs, amax = ctx.saved_tensors
-        zero = grad_outputs.new_zeros(1)
-        grad_inputs = torch.where(inputs.abs() <= amax, grad_outputs, zero)
-        return grad_inputs, None, None, None, None, None
-
-
def _tensor_quant(inputs, amax, num_bits=8, unsigned=False, narrow_range=True):
    """Shared function body between TensorQuantFunction and FakeTensorQuantFunction."""
    # Fine scale, per channel scale will be handled by broadcasting, which could be tricky. Pop a warning.
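The broadcasting this comment refers to can be illustrated with a per-channel amax; the shapes below are an example, not from this file:

```python
import torch

w = torch.randn(64, 32, 3, 3)                     # e.g. a conv weight
amax = w.abs().amax(dim=(1, 2, 3), keepdim=True)  # one amax per output channel: (64, 1, 1, 1)
scale = 127.0 / amax                              # broadcasts against w
w_q = torch.clamp(torch.round(w * scale), -127, 127) / scale
```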
@@ -694,10 +561,8 @@ def _tensor_quant(inputs, amax, num_bits=8, unsigned=False, narrow_range=True):

    # Computation can be done in FP32 to prevent potential overflow.
    input_dtype = inputs.dtype
-    if inputs.dtype == torch.half:
-        inputs = inputs.float()
-    if amax.dtype == torch.half:
-        amax = amax.float()
+    inputs = inputs.float()
+    amax = amax.float()

    min_amax = amax.min()
    if min_amax < 0:
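The unconditional upcast matters because the derived scale can exceed FP16's maximum of 65504 even when amax itself is representable. A small illustration:

```python
import torch

amax = torch.tensor(1e-3, dtype=torch.half)
scale_fp16 = 127.0 / amax          # inf: ~1.27e5 overflows FP16
scale_fp32 = 127.0 / amax.float()  # ~1.27e5, fine in FP32
```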
@@ -724,72 +589,10 @@ def _tensor_quant(inputs, amax, num_bits=8, unsigned=False, narrow_range=True):
        1.0  # Return 1 makes more sense for values quantized to 0 with amax=0
    )

-    if input_dtype == torch.half:
-        outputs = outputs.half()
-
+    outputs = outputs.to(input_dtype)
    return outputs, scale


-class FakeAffineTensorQuantFunction(Function):
-    """Fake version of affine quantization.
-
-    gemmlowp-style scale+shift quantization. See more details in
-    https://github.com/google/gemmlowp/blob/master/doc/quantization.md.
-
-    We DO NOT recommend affine quantization on weights for performance reasons. There might be value in affine
-    quantizing activations, as it can be cancelled by bias and comes with no performance penalty. This functionality
-    is only added for experimental purposes.
-    """
-
-    @staticmethod
-    def forward(ctx, inputs, min_range, max_range, num_bits=8):
-        """As it will only be applied on activations with per-tensor granularity, broadcast is not needed.
-
-        Args:
-            ctx: Pytorch convention.
-            inputs: A Tensor of type float32.
-            min_range: A float.
-            max_range: A float.
-            num_bits: An integer.
-
-        Returns:
-            outputs: A Tensor of type output_dtype.
-        """
-        ctx.save_for_backward(inputs, min_range, max_range)
-
-        step_size = (max_range - min_range) / (2.0**num_bits - 1)
-
-        min_bound = -(2.0 ** (num_bits - 1))
-        max_bound = 2.0 ** (num_bits - 1) - 1
-
-        quant_zero = torch.round(min_range / step_size) - min_bound
-        quantized = torch.round(inputs / step_size) - quant_zero
-        quantized = torch.clamp(quantized, min_bound, max_bound)
-
-        outputs = (quantized + quant_zero) * step_size
-
-        return outputs
-
-    @staticmethod
-    def backward(ctx, grad_outputs):
-        """Implements straight-through estimation with clipping.
-
-        Args:
-            ctx: Pytorch convention.
-            grad_outputs: A tensor of gradient of outputs.
-
-        Returns:
-            grad_inputs: A tensor of gradient.
-        """
-        inputs, min_range, max_range = ctx.saved_tensors
-        zero = grad_outputs.new_zeros(1)
-        grad_inputs = torch.where((inputs <= max_range) * (inputs >= min_range), grad_outputs, zero)
-        return grad_inputs, None, None, None
-
-
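A worked pass through the removed affine scheme, with illustrative numbers:

```python
import torch

min_range, max_range, num_bits = -1.0, 1.5, 8
step = (max_range - min_range) / (2.0**num_bits - 1)         # 2.5 / 255 ≈ 0.0098
min_b, max_b = -(2.0 ** (num_bits - 1)), 2.0 ** (num_bits - 1) - 1
zero = round(min_range / step) - min_b                       # round(-102) + 128 = 26
x = torch.tensor([0.7])
q = torch.clamp(torch.round(x / step) - zero, min_b, max_b)  # tensor([45.])
x_hat = (q + zero) * step                                    # ≈ tensor([0.6961])
```

The dequantized values land on a grid shifted by the zero point, which is why the class is "fake": it always returns floats, never integers.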
-tensor_quant = TensorQuantFunction.apply
-legacy_fake_tensor_quant = LegacyFakeTensorQuantFunction.apply
fake_tensor_quant = FakeTensorQuantFunction.apply
-fake_affine_tensor_quant = FakeAffineTensorQuantFunction.apply
scaled_e4m3 = ScaledE4M3Function.apply
dynamic_block_quant = DynamicBlockQuantizationFunction.apply