|
38 | 38 | HistogramObserver, |
39 | 39 | MinMaxObserver, |
40 | 40 | MovingAverageMinMaxObserver, |
41 | | - MovingAveragePerChannelMinMaxObserver, |
42 | 41 | ObserverOrFakeQuantizeConstructor, |
43 | 42 | PerChannelMinMaxObserver, |
44 | 43 | PlaceholderObserver, |
@@ -95,24 +94,26 @@ def get_symmetric_quantization_config( |
95 | 94 | **extra_args, |
96 | 95 | ), |
97 | 96 | ) |
| 97 | + |
| 98 | + # Setup quantization config for weights |
98 | 99 | weight_qscheme = ( |
99 | 100 | torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric |
100 | 101 | ) |
101 | 102 | weight_observer_or_fake_quant_ctr: ObserverOrFakeQuantizeConstructor = ( |
102 | 103 | MinMaxObserver |
103 | 104 | ) |
| 105 | + # Determine the right observer/fake-quant constructor |
104 | 106 | if is_qat: |
105 | | - # TODO: qat + per channel? |
106 | | - weight_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize |
107 | | - elif is_per_channel: |
108 | | - weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver |
| 107 | + # Set plain fake-quant with true min/max |
| 108 | + weight_observer_or_fake_quant_ctr = FakeQuantize |
| 109 | + else: |
| 110 | + # PTQ: set min/max observer |
| 111 | + weight_observer_or_fake_quant_ctr = ( |
| 112 | + PerChannelMinMaxObserver if is_per_channel else MinMaxObserver |
| 113 | + ) |
| 114 | + |
| 115 | + extra_args = {"eps": 2**-12} |
109 | 116 |
|
110 | | - extra_args: Dict[str, Any] = {"eps": 2**-12} |
111 | | - if is_qat: |
112 | | - if weight_qscheme == torch.per_tensor_symmetric: |
113 | | - extra_args["observer"] = MovingAverageMinMaxObserver |
114 | | - else: |
115 | | - extra_args["observer"] = MovingAveragePerChannelMinMaxObserver # type: ignore[dict-item] |
116 | 117 | weight_quantization_spec = QuantizationSpec( |
117 | 118 | dtype=torch.int8, |
118 | 119 | quant_min=weight_qmin, |
|
0 commit comments