1616from __future__ import absolute_import
1717from __future__ import division
1818from __future__ import print_function
19+
1920import warnings
2021
22+ import numpy as np
2123import tensorflow as tf
2224from tensorflow .keras import constraints
2325from tensorflow .keras import initializers
2628from tensorflow .keras .layers import Conv1D
2729from tensorflow .keras .layers import Conv2D
2830from tensorflow .keras .layers import Conv2DTranspose
29- from tensorflow .keras .layers import SeparableConv1D
30- from tensorflow .keras .layers import SeparableConv2D
3131from tensorflow .keras .layers import DepthwiseConv2D
3232from tensorflow .keras .layers import Dropout
3333from tensorflow .keras .layers import InputSpec
34- from tensorflow .python . eager import context
35- from tensorflow .python . ops import array_ops
36- # from tensorflow.python.ops import array_ops
34+ from tensorflow .keras . layers import SeparableConv1D
35+ from tensorflow .keras . layers import SeparableConv2D
36+
3737from .qlayers import get_auto_range_constraint_initializer
3838from .qlayers import QActivation
3939from .quantizers import get_quantized_initializer
4040from .quantizers import get_quantizer
41-
41+ from tensorflow .python .eager import context
42+ from tensorflow .python .ops import array_ops
43+ # from tensorflow.python.ops import array_ops
4244from tensorflow_model_optimization .python .core .sparsity .keras .prunable_layer import PrunableLayer
4345
4446
@@ -260,32 +262,36 @@ class QConv2D(Conv2D, PrunableLayer):
260262 # can go over [-1,+1], these values are used to set the clipping
261263 # value of kernels and biases, respectively, instead of using the
262264 # constraints specified by the user.
265+ # mask: Optional mask for kernel weights.
263266 #
264267 # we refer the reader to the documentation of Conv2D in Keras for the
265268 # other parameters.
266269 #
267270
268- def __init__ (self ,
269- filters ,
270- kernel_size ,
271- strides = (1 , 1 ),
272- padding = "valid" ,
273- data_format = "channels_last" ,
274- dilation_rate = (1 , 1 ),
275- activation = None ,
276- use_bias = True ,
277- kernel_initializer = "he_normal" ,
278- bias_initializer = "zeros" ,
279- kernel_regularizer = None ,
280- bias_regularizer = None ,
281- activity_regularizer = None ,
282- kernel_constraint = None ,
283- bias_constraint = None ,
284- kernel_range = None ,
285- bias_range = None ,
286- kernel_quantizer = None ,
287- bias_quantizer = None ,
288- ** kwargs ):
271+ def __init__ (
272+ self ,
273+ filters ,
274+ kernel_size ,
275+ strides = (1 , 1 ),
276+ padding = "valid" ,
277+ data_format = "channels_last" ,
278+ dilation_rate = (1 , 1 ),
279+ activation = None ,
280+ use_bias = True ,
281+ kernel_initializer = "he_normal" ,
282+ bias_initializer = "zeros" ,
283+ kernel_regularizer = None ,
284+ bias_regularizer = None ,
285+ activity_regularizer = None ,
286+ kernel_constraint = None ,
287+ bias_constraint = None ,
288+ kernel_range = None ,
289+ bias_range = None ,
290+ kernel_quantizer = None ,
291+ bias_quantizer = None ,
292+ mask = None ,
293+ ** kwargs ,
294+ ):
289295
290296 if kernel_range is not None :
291297 warnings .warn ("kernel_range is deprecated in QConv2D layer." )
@@ -324,6 +330,20 @@ def __init__(self,
324330 if activation is not None :
325331 activation = get_quantizer (activation )
326332
333+ if mask is not None :
334+ shape = mask .shape
335+ if len (shape ) < 2 :
336+ raise ValueError (
337+ "Expected shape to have rank at least 2 but provided shape has"
338+ f" rank { len (shape )} "
339+ )
340+ h , w = shape [0 ], shape [1 ]
341+ self ._mask = np .reshape (
342+ mask , (h , w , 1 , 1 )
343+ ) # Extend the dimension to be 4D.
344+ else :
345+ self ._mask = None
346+
327347 super ().__init__ (
328348 filters = filters ,
329349 kernel_size = kernel_size ,
@@ -343,19 +363,44 @@ def __init__(self,
343363 ** kwargs
344364 )
345365
366+ def convolution_op (self , inputs , kernel ):
367+ return tf .keras .backend .conv2d (
368+ inputs ,
369+ kernel ,
370+ strides = self .strides ,
371+ padding = self .padding ,
372+ data_format = self .data_format ,
373+ dilation_rate = self .dilation_rate ,
374+ )
375+
376+ @tf .function (jit_compile = True )
377+ def _jit_compiled_convolution_op (self , inputs , kernel ):
378+ return self .convolution_op (inputs , kernel )
379+
346380 def call (self , inputs ):
347381 if self .kernel_quantizer :
348382 quantized_kernel = self .kernel_quantizer_internal (self .kernel )
349383 else :
350384 quantized_kernel = self .kernel
351385
352- outputs = tf .keras .backend .conv2d (
353- inputs ,
354- quantized_kernel ,
355- strides = self .strides ,
356- padding = self .padding ,
357- data_format = self .data_format ,
358- dilation_rate = self .dilation_rate )
386+ if self ._mask is not None :
387+ # Apply mask to kernel weights if one is provided.
388+ quantized_kernel = quantized_kernel * self ._mask
389+
390+ # Grouped convolutions are not fully supported on the CPU for compiled
391+ # functions.
392+ #
393+ # This is a workaround taken from TF's core library. Remove when proper
394+ # support is added.
395+ # See definition of function "_jit_compiled_convolution_op" at
396+ # cs/third_party/py/tf_keras/layers/convolutional/base_conv.py for more
397+ # details.
398+ if self .groups > 1 :
399+ outputs = self ._jit_compiled_convolution_op (
400+ inputs , tf .convert_to_tensor (quantized_kernel )
401+ )
402+ else :
403+ outputs = self .convolution_op (inputs , quantized_kernel )
359404
360405 if self .use_bias :
361406 if self .bias_quantizer :
@@ -364,7 +409,8 @@ def call(self, inputs):
364409 quantized_bias = self .bias
365410
366411 outputs = tf .keras .backend .bias_add (
367- outputs , quantized_bias , data_format = self .data_format )
412+ outputs , quantized_bias , data_format = self .data_format
413+ )
368414
369415 if self .activation is not None :
370416 return self .activation (outputs )
@@ -380,10 +426,19 @@ def get_config(self):
380426 ),
381427 "kernel_range" : self .kernel_range ,
382428 "bias_range" : self .bias_range ,
429+ "mask" : self ._mask .tolist () if self ._mask is not None else None ,
383430 }
384- base_config = super (QConv2D , self ).get_config ()
431+ base_config = super ().get_config ()
385432 return dict (list (base_config .items ()) + list (config .items ()))
386433
434+ @classmethod
435+ def from_config (cls , config ):
436+ mask = config .get ("mask" )
437+ if mask is not None :
438+ mask = np .array (mask )
439+ config ["mask" ] = mask
440+ return cls (** config )
441+
387442 def get_quantization_config (self ):
388443 return {
389444 "kernel_quantizer" :
0 commit comments