4141 configure_module_qat_wrappers ,
4242 fuse_module_conv_bn_relus ,
4343 get_qat_qconfig ,
44+ prepare_embeddings_qat ,
4445)
4546
4647
@@ -80,6 +81,10 @@ class QuantizationModifier(ScheduledModifier):
8081 exception. For compatibility with YAML serialization only.
8182 :param model_fuse_fn_kwargs: dictionary of keyword argument values to be passed
8283 to the model fusing function
84+ :param quantize_embeddings: if True, will perform QAT on torch.nn.Embedding layers
85+ using sparseml.pytorch.utils.quantization.prepare_embeddings_qat to fake
86+ quantize embedding weights. Default is True. Models without embedding layers
87+ will be unaffected
8388 """
8489
8590 def __init__ (
@@ -91,6 +96,7 @@ def __init__(
9196 freeze_bn_stats_epoch : Union [float , None ] = None ,
9297 end_epoch : float = - 1 ,
9398 model_fuse_fn_kwargs : Dict [str , Any ] = None ,
99+ quantize_embeddings : bool = True ,
94100 ):
95101 if torch_quantization is None or torch_intrinsic is None :
96102 raise RuntimeError (
@@ -112,6 +118,7 @@ def __init__(
112118 self ._model_fuse_fn_kwargs = model_fuse_fn_kwargs or {}
113119 self ._disable_quantization_observer_epoch = disable_quantization_observer_epoch
114120 self ._freeze_bn_stats_epoch = freeze_bn_stats_epoch
121+ self ._quantize_embeddings = quantize_embeddings
115122
116123 self ._modules_to_quantize = None
117124 self ._qat_enabled = False
@@ -140,7 +147,7 @@ def submodules(self) -> Union[List[str], None]:
140147 def submodules (self , value : Union [List [str ], None ]):
141148 """
142149 :params value: List of submodule names to perform QAT on. Set None to quantize
143- entire model
150+ entire model
144151 """
145152 self ._submodules = value
146153 if isinstance (self ._submodules , list ):
@@ -151,18 +158,18 @@ def submodules(self, value: Union[List[str], None]):
151158 def model_fuse_fn_name (self ) -> Union [str , None ]:
152159 """
153160 :return: Name of model function to fuse the model in place prior
154- to performing QAT. None to uses the default function
155- `sparseml.pytorch.utils.fuse_module_conv_bn_relus`.
161+ to performing QAT. None to use the default function
162+ `sparseml.pytorch.utils.fuse_module_conv_bn_relus`.
156163 """
157164 return self ._model_fuse_fn_name
158165
159166 @model_fuse_fn_name .setter
160167 def model_fuse_fn_name (self , value : Union [str , None ]):
161168 """
162169 :params value: Name of model function to fuse the model in place prior
163- to performing QAT. Set None to use the default function
164- `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. Set as 'no_fuse'
165- to skip module fusing.
170+ to performing QAT. Set None to use the default function
171+ `sparseml.pytorch.utils.fuse_module_conv_bn_relus`. Set as 'no_fuse'
172+ to skip module fusing.
166173 """
167174 self ._model_fuse_fn_name = value
168175 if (
@@ -176,17 +183,17 @@ def model_fuse_fn_name(self, value: Union[str, None]):
176183 def disable_quantization_observer_epoch (self ) -> Union [float , None ]:
177184 """
178185 :return: Epoch to disable updates to the module's
179- quantization observers. After this point, quantized weights and zero points will
180- not be updated. When None, observers never disabled during QAT
186+ quantization observers. After this point, quantized weights and zero points
187+ will not be updated. When None, observers never disabled during QAT
181188 """
182189 return self ._disable_quantization_observer_epoch
183190
184191 @disable_quantization_observer_epoch .setter
185192 def disable_quantization_observer_epoch (self , value : Union [float , None ]):
186193 """
187194 :params value: Epoch to disable updates to the module's
188- quantization observers. After this point, quantized weights and zero points will
189- not be updated. Set None to not disable observers during QAT
195+ quantization observers. After this point, quantized weights and zero points
196+ will not be updated. Set None to not disable observers during QAT
190197 """
191198 self ._disable_quantization_observer_epoch = value
192199 self ._validate_params ()
@@ -195,19 +202,37 @@ def disable_quantization_observer_epoch(self, value: Union[float, None]):
195202 def freeze_bn_stats_epoch (self ) -> Union [float , None ]:
196203 """
197204 :return: Epoch to stop the tracking of batch norm stats. When
198- None, batch norm stats are track for all of training
205+ None, batch norm stats are track for all of training
199206 """
200207 return self ._freeze_bn_stats_epoch
201208
202209 @freeze_bn_stats_epoch .setter
203210 def freeze_bn_stats_epoch (self , value : Union [float , None ]):
204211 """
205212 :params value: Epoch to stop the tracking of batch norm stats. Set
206- None to not stop tracking batch norm stats during QAT
213+ None to not stop tracking batch norm stats during QAT
207214 """
208215 self ._freeze_bn_stats_epoch = value
209216 self ._validate_params ()
210217
218+ @ModifierProp ()
219+ def quantize_embeddings (self ) -> bool :
220+ """
221+ :return: if True, will perform QAT on torch.nn.Embedding layers
222+ using sparseml.pytorch.utils.quantization.prepare_embeddings_qat to fake
223+ quantize embedding weights
224+ """
225+ return self ._freeze_bn_stats_epoch
226+
227+ @quantize_embeddings .setter
228+ def quantize_embeddings (self , value : bool ):
229+ """
230+ :params value: if True, will perform QAT on torch.nn.Embedding layers
231+ using sparseml.pytorch.utils.quantization.prepare_embeddings_qat to fake
232+ quantize embedding weights
233+ """
234+ self ._quantize_embeddings = value
235+
211236 def initialize (
212237 self ,
213238 module : Module ,
@@ -350,6 +375,8 @@ def _enable_module_qat(self, module: Module):
350375 add_quant_dequant (quant_module )
351376 # set model to QAT mode
352377 torch_quantization .prepare_qat (quant_module , inplace = True )
378+ if self ._quantize_embeddings :
379+ prepare_embeddings_qat (quant_module )
353380 self ._qat_enabled = True
354381
355382 def _disable_quantization_observer_update_ready (self , epoch : float ) -> bool :
0 commit comments