Skip to content

Commit ba2b0aa

Browse files
authored
Fix factory_kwargs parameter mutation bug and add quant_min/quant_max compatibility for torch 1.10 (#47)
1 parent f65f407 commit ba2b0aa

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

mqbench/observer.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import math
22
from functools import partial
33
from typing import Tuple
4-
4+
from copy import deepcopy
55
import torch
66
from torch.quantization.observer import _ObserverBase
77

@@ -28,8 +28,13 @@ class ObserverBase(_ObserverBase):
2828
def __init__(self, dtype=torch.quint8, qscheme=torch.per_tensor_affine,
2929
reduce_range=False, quant_min=None, quant_max=None, ch_axis=-1, pot_scale=False,
3030
factory_kwargs=None):
31+
factory_kwargs = deepcopy(factory_kwargs)
3132
self.not_calc_quant_min_max = factory_kwargs.pop('not_calc_quant_min_max', False) if isinstance(factory_kwargs, dict) else False
3233
super(ObserverBase, self).__init__(dtype, qscheme, reduce_range, quant_min, quant_max)
34+
# For compatibility with torch 1.10: re-assign quant_min/quant_max so the values passed by the caller are not overwritten by the parent constructor, then recompute the effective range.
35+
self.quant_min = quant_min
36+
self.quant_max = quant_max
37+
self.quant_min, self.quant_max = self._calculate_qmin_qmax()
3338
self.ch_axis = ch_axis
3439
self.pot_scale = pot_scale
3540
self.register_buffer("min_val", torch.tensor(float("inf")))

0 commit comments

Comments
 (0)