import torch
from torch.ao.quantization.observer import UniformQuantizationObserverBase


# TODO move to torch/ao/quantization/observer.py.
class PerChannelParamObserver(UniformQuantizationObserverBase):
    """
    Minimize quantization loss caused by outliers via a linear search.
    More details can be found at https://arxiv.org/pdf/2209.13325
    """

    def __init__(
        self,
        ch_axis=0,
        use_mse=True,
        steps=100,
        dtype=torch.int8,
        qscheme=torch.per_channel_symmetric,
        reduce_range=False,
        quant_min=None,
        quant_max=None,
        factory_kwargs=None,
        eps=torch.finfo(torch.float32).eps,  # noqa: B008
        is_dynamic=False,
        **kwargs,
    ) -> None:
        super().__init__(
            dtype=dtype,
            qscheme=qscheme,
            reduce_range=reduce_range,
            quant_min=quant_min,
            quant_max=quant_max,
            factory_kwargs=factory_kwargs,
            eps=eps,
            is_dynamic=is_dynamic,
            **kwargs,
        )

        factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
        self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs))
        self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs))
        self.ch_axis = ch_axis
        self.use_mse = use_mse
        self.steps = steps
        self.calibrated = False

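    # Move the channel axis to dim 0 and flatten the remaining dims so that
    # per-channel statistics can be computed along dim 1.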
    def to_ch_axis(self, x):
        axis_order = list(range(len(x.size())))
        axis_order[self.ch_axis], axis_order[0] = 0, self.ch_axis
        return torch.flatten(x.permute(axis_order), start_dim=1)

    def mse(self, pred, expect):
        loss = (pred - expect).abs().pow(2)
        return self.to_ch_axis(loss).mean(1)

    def cosine(self, pred, expect):
        # a target of ones asks CosineEmbeddingLoss to maximize the cosine
        # similarity between the fake-quantized and the original tensor
        target = torch.ones(pred.shape[self.ch_axis], device=pred.device)
        pred_n = self.to_ch_axis(pred).reshape(pred.shape[self.ch_axis], -1)
        expect_n = self.to_ch_axis(expect).reshape(expect.shape[self.ch_axis], -1)
        return torch.nn.CosineEmbeddingLoss()(pred_n, expect_n, target)

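    # Fake-quantize x with the candidate (new_min, new_max) clipping range and
    # measure the reconstruction error (per-channel MSE or cosine distance)
    # against the original tensor.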
    def loss_fn(self, x, new_min, new_max):
        scale, offset = self._calculate_qparams(new_min, new_max)
        x_q = torch.fake_quantize_per_channel_affine(
            x,
            scale.data,
            offset.data.int(),
            self.ch_axis,
            self.quant_min,
            self.quant_max,
        )
        return self.mse(x_q, x) if self.use_mse else self.cosine(x_q, x)

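    # Linear search: try `steps` progressively larger symmetric clipping
    # thresholds and keep the range that yields the smallest quantization loss.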
    def line_search(self, x):
        x_min, x_max = torch.aminmax(self.to_ch_axis(x), dim=1)
        x_range = torch.max(x_min.abs(), x_max)
        optimal_loss = torch.zeros_like(x_min) + 1e9

        # check which clipping range produces the smallest loss
        for i in range(1, self.steps + 1):
            thres = x_range / self.steps * i
            current_loss = self.loss_fn(x, -thres, thres)
            x_min = torch.where(current_loss < optimal_loss, -thres, x_min)
            x_max = torch.where(current_loss < optimal_loss, thres, x_max)
            optimal_loss = torch.min(current_loss, optimal_loss)

        return x_min, x_max

    def forward(self, x_orig):
        # since params are static, one calibration is enough
        if not self.calibrated:
            x = x_orig.detach().to(self.min_val.dtype)
            self.min_val, self.max_val = self.line_search(x)
            self.calibrated = True

        # return fake-quant result for saturating outliers
        scale, zero_point = self._calculate_qparams(self.min_val, self.max_val)
        return torch.fake_quantize_per_channel_affine(
            x_orig,
            scale.data,
            zero_point.data.int(),
            self.ch_axis,
            self.quant_min,
            self.quant_max,
        )

    @torch.jit.export
    def calculate_qparams(self):
        return self._calculate_qparams(self.min_val, self.max_val)
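

# Hedged usage sketch (not part of the original code): calibrate the observer
# on a random weight tensor and read back the per-channel quantization
# parameters. The tensor shape and `steps` value are illustrative assumptions.
if __name__ == "__main__":
    weight = torch.randn(8, 16)  # [out_channels, in_channels]
    observer = PerChannelParamObserver(ch_axis=0, steps=50)
    fq_weight = observer(weight)  # first call runs the line search and calibrates
    scale, zero_point = observer.calculate_qparams()
    print(scale.shape, zero_point.shape)  # one qparam pair per output channel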