This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 1db70be

Authored by anmarques, bfineran, spacemanidol, and Benjamin
Quantization refactor (#663)
* Removed output quantization from conv layers.
* Added _AddReLU module that enables QATWrapper for quantization.
* Removed quantization of the output for linear and conv layers by default. Removed fusing of BN and ReLU by default.
* Minor fixes; style and quality fixes.
* Added support for freezing BN stats.
* Added mode argument to the wrapping of the train function in BNWrapper.
* Set BN fusing back as the default.
* Fixed custom freeze_bn_stats.
* Added temporary files for evaluating changes to graphs.
* Added support for the tensorrt flag. Moved the computation of the quantization range to get_qat_qconfig, where it has full information about the data type.
* Added support for TensorRT quantization.
* Included a check to account for when weight_qconfig_kwargs is None.
* Modified argument names for backwards compatibility.
* Updated documentation to reflect the changes.
* Fixed the default weights data type.
* Removed an unused method.
* Removed testing files.
* Changed the call to get_qat_qconfig to not specify symmetry and data type arguments in the default case.
* Changed the default number of activation and weight bits from None to 8.
* Revert "Changed default number of activation and weight bits from None to 8." This reverts commit 95e966ed929fa3512331a73667d5ba2ac3d594b1.
* Revert "Changed call to get_qat_qconfig to not specify symmetry and data type arguments for default case." This reverts commit a675813.
* Lumped qconfig properties into a dataclass.
* Reset conv and linear activation flags to True.
* Renamed class BNWrapper to _BNWrapper.
* Added logging messages for when tensorrt forces overriding of configs.
* ConvInteger quantization conversion for quant refactor (#644)
* [quantization-refactor] mark/propagate conv export mode (#672)
* Fixed a batch norm fold bug with an existing bias param.
* Quantization Refactor Tests (#685)
* Fixed an import after rebase.
* Updated manager serialization test cases for the new quantization params.

Co-authored-by: Benjamin Fineran <[email protected]>
Co-authored-by: spacemanidol <[email protected]>
Co-authored-by: Benjamin <[email protected]>
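The commit message notes that the quantization-range computation was moved to a point where the full data type is known. As an illustrative sketch only (not the SparseML implementation), the observer range for a fake-quantize config can be derived from the bit width and signedness; a signed symmetric int8 range is what TensorRT-style quantization expects, while unsigned uint8 is the usual default for QAT activations:

```python
def quantization_range(num_bits: int, signed: bool):
    """Return (quant_min, quant_max) for an integer data type.

    signed=True  -> e.g. int8:  (-128, 127), TensorRT-style range
    signed=False -> e.g. uint8: (0, 255), typical QAT activation range
    """
    if signed:
        return -(2 ** (num_bits - 1)), 2 ** (num_bits - 1) - 1
    return 0, 2**num_bits - 1


print(quantization_range(8, signed=True))   # (-128, 127)
print(quantization_range(8, signed=False))  # (0, 255)
```

Computing the range this way requires knowing both the bit width and the target data type together, which is why the PR moves it into the qconfig construction.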
1 parent: a612e7b · commit: 1db70be

File tree

8 files changed: +794 −317 lines changed

src/sparseml/pytorch/models/classification/resnet.py

Lines changed: 26 additions & 22 deletions
@@ -47,7 +47,6 @@
 except Exception:
     FloatFunctional = None
 
-
 __all__ = [
     "ResNetSectionSettings",
     "ResNet",
@@ -141,6 +140,28 @@ def required(in_channels: int, out_channels: int, stride: int) -> bool:
     return in_channels != out_channels or stride > 1
 
 
+class _AddReLU(Module):
+    """
+    Wrapper for the FloatFunctional class that enables QATWrapper used to
+    quantize the first input to the Add operation
+    """
+
+    def __init__(self, num_channels):
+        super().__init__()
+        if FloatFunctional:
+            self.functional = FloatFunctional()
+            self.wrap_qat = True
+            self.qat_wrapper_kwargs = {"num_inputs": 1, "num_outputs": 0}
+        else:
+            self.functional = ReLU(num_channels=num_channels, inplace=True)
+
+    def forward(self, x, y):
+        if isinstance(self.functional, FloatFunctional):
+            return self.functional.add_relu(x, y)
+        else:
+            return self.functional(x + y)
+
+
 class _BasicBlock(Module):
     def __init__(self, in_channels: int, out_channels: int, stride: int = 1):
         super().__init__()
@@ -164,11 +185,7 @@ def __init__(self, in_channels: int, out_channels: int, stride: int = 1):
             else None
         )
 
-        self.add_relu = (
-            FloatFunctional()
-            if FloatFunctional is not None
-            else ReLU(num_channels=out_channels, inplace=True)
-        )
+        self.add_relu = _AddReLU(out_channels)
 
         self.initialize()
 
@@ -181,12 +198,7 @@ def forward(self, inp: Tensor):
         out = self.bn2(out)
 
         identity_val = self.identity(inp) if self.identity is not None else inp
-
-        if isinstance(self.add_relu, FloatFunctional):
-            out = self.add_relu.add_relu(out, identity_val)
-        else:
-            out += identity_val
-            out = self.add_relu(out)
+        out = self.add_relu(identity_val, out)
 
         return out
 
@@ -230,11 +242,7 @@ def __init__(
             else None
        )
 
-        self.add_relu = (
-            FloatFunctional()
-            if FloatFunctional is not None
-            else ReLU(num_channels=out_channels, inplace=True)
-        )
+        self.add_relu = _AddReLU(out_channels)
 
        self.initialize()
 
@@ -252,11 +260,7 @@ def forward(self, inp: Tensor):
 
         identity_val = self.identity(inp) if self.identity is not None else inp
 
-        if isinstance(self.add_relu, FloatFunctional):
-            out = self.add_relu.add_relu(out, identity_val)
-        else:
-            out += identity_val
-            out = self.add_relu(out)
+        out = self.add_relu(identity_val, out)
 
         return out
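The diff routes both code paths through `_AddReLU.forward`: the `FloatFunctional.add_relu` call when quantization support is available, and an eager fallback otherwise. The numeric behavior of the fallback branch, `relu(x + y)`, which `FloatFunctional.add_relu` fuses into a single observable op for QAT, can be sketched in plain Python (illustrative only, element-wise on lists rather than tensors):

```python
def add_relu_reference(x, y):
    # Element-wise add followed by ReLU: max(a + b, 0).
    # This is the math that FloatFunctional.add_relu performs as one
    # fused op, so QAT observes a single add+relu instead of two ops.
    return [max(a + b, 0.0) for a, b in zip(x, y)]


print(add_relu_reference([1.0, -3.0, 2.0], [0.5, 1.0, -5.0]))  # [1.5, 0.0, 0.0]
```

Fusing the add and ReLU matters for quantization because only the post-ReLU output needs an observer, which is why the PR replaces the inline `isinstance` branches with a single wrappable module.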
0 commit comments