|
38 | 38 | HistogramObserver,
|
39 | 39 | MinMaxObserver,
|
40 | 40 | MovingAverageMinMaxObserver,
|
41 |
| - MovingAveragePerChannelMinMaxObserver, |
42 | 41 | ObserverOrFakeQuantizeConstructor,
|
43 | 42 | PerChannelMinMaxObserver,
|
44 | 43 | PlaceholderObserver,
|
@@ -95,24 +94,26 @@ def get_symmetric_quantization_config(
|
95 | 94 | **extra_args,
|
96 | 95 | ),
|
97 | 96 | )
|
| 97 | + |
| 98 | + # Set up quantization config for weights |
98 | 99 | weight_qscheme = (
|
99 | 100 | torch.per_channel_symmetric if is_per_channel else torch.per_tensor_symmetric
|
100 | 101 | )
|
101 | 102 | weight_observer_or_fake_quant_ctr: ObserverOrFakeQuantizeConstructor = (
|
102 | 103 | MinMaxObserver
|
103 | 104 | )
|
| 105 | + # Determine the right observer/fake-quant constructor |
104 | 106 | if is_qat:
|
105 |
| - # TODO: qat + per channel? |
106 |
| - weight_observer_or_fake_quant_ctr = FusedMovingAvgObsFakeQuantize |
107 |
| - elif is_per_channel: |
108 |
| - weight_observer_or_fake_quant_ctr = PerChannelMinMaxObserver |
| 107 | + # Set plain fake-quant with true min/max |
| 108 | + weight_observer_or_fake_quant_ctr = FakeQuantize |
| 109 | + else: |
| 110 | + # PTQ: set min/max observer |
| 111 | + weight_observer_or_fake_quant_ctr = ( |
| 112 | + PerChannelMinMaxObserver if is_per_channel else MinMaxObserver |
| 113 | + ) |
| 114 | + |
| 115 | + extra_args = {"eps": 2**-12} |
109 | 116 |
|
110 |
| - extra_args: Dict[str, Any] = {"eps": 2**-12} |
111 |
| - if is_qat: |
112 |
| - if weight_qscheme == torch.per_tensor_symmetric: |
113 |
| - extra_args["observer"] = MovingAverageMinMaxObserver |
114 |
| - else: |
115 |
| - extra_args["observer"] = MovingAveragePerChannelMinMaxObserver # type: ignore[dict-item] |
116 | 117 | weight_quantization_spec = QuantizationSpec(
|
117 | 118 | dtype=torch.int8,
|
118 | 119 | quant_min=weight_qmin,
|
|
0 commit comments