11#!/usr/bin/env python
22# -*- coding: utf-8 -*-
3- """conv.py"""
4- # pylint: disable=protected-access, consider-using-enumerate, too-many-lines
5-
6- #
7- # Authors: Chiheb Trabelsi
83
94import tensorflow as tf
105from tensorflow .keras import backend as K
2419
2520def conv1d_transpose (
2621 inputs ,
27- filter , # pylint: disable=redefined-builtin
22+ filter ,
2823 kernel_size = None ,
2924 filters = None ,
3025 strides = (1 ,),
@@ -76,7 +71,7 @@ def conv1d_transpose(
7671
7772def conv2d_transpose (
7873 inputs ,
79- filter , # pylint: disable=redefined-builtin
74+ filter ,
8075 kernel_size = None ,
8176 filters = None ,
8277 strides = (1 , 1 ),
@@ -121,9 +116,7 @@ def conv2d_transpose(
121116 output_shape = (batch_size , out_height , out_width , filters )
122117
123118 filter = K .permute_dimensions (filter , (0 , 1 , 3 , 2 ))
124- return K .conv2d_transpose (
125- inputs , filter , output_shape , strides , padding = padding , data_format = data_format
126- )
119+ return K .conv2d_transpose (inputs , filter , output_shape , strides , padding = padding , data_format = data_format )
127120
128121
129122def ifft (f ):
@@ -136,9 +129,7 @@ def ifft2(f):
136129 raise NotImplementedError (str (f ))
137130
138131
139- def conv_transpose_output_length (
140- input_length , filter_size , padding , stride , dilation = 1 , output_padding = None
141- ):
132+ def conv_transpose_output_length (input_length , filter_size , padding , stride , dilation = 1 , output_padding = None ):
142133 """Rearrange arguments for compatibility with conv_output_length."""
143134 if dilation != 1 :
144135 msg = f"Dilation must be 1 for transposed convolution. "
@@ -287,14 +278,8 @@ def __init__(
287278 self .kernel_size = conv_utils .normalize_tuple (kernel_size , rank , "kernel_size" )
288279 self .strides = conv_utils .normalize_tuple (strides , rank , "strides" )
289280 self .padding = conv_utils .normalize_padding (padding )
290- self .data_format = (
291- "channels_last"
292- if rank == 1
293- else conv_utils .normalize_data_format (data_format )
294- )
295- self .dilation_rate = conv_utils .normalize_tuple (
296- dilation_rate , rank , "dilation_rate"
297- )
281+ self .data_format = "channels_last" if rank == 1 else conv_utils .normalize_data_format (data_format )
282+ self .dilation_rate = conv_utils .normalize_tuple (dilation_rate , rank , "dilation_rate" )
298283 self .activation = activations .get (activation )
299284 self .use_bias = use_bias
300285 self .normalize_weight = normalize_weight
@@ -336,10 +321,7 @@ def build(self, input_shape):
336321 else :
337322 channel_axis = - 1
338323 if input_shape [channel_axis ] is None :
339- raise ValueError (
340- "The channel dimension of the inputs "
341- "should be defined. Found `None`."
342- )
324+ raise ValueError ("The channel dimension of the inputs " "should be defined. Found `None`." )
343325 # Divide by 2 for real and complex input.
344326 input_dim = input_shape [channel_axis ] // 2
345327 if False and self .transposed :
@@ -421,9 +403,7 @@ def build(self, input_shape):
421403 self .bias = None
422404
423405 # Set input spec.
424- self .input_spec = InputSpec (
425- ndim = self .rank + 2 , axes = {channel_axis : input_dim * 2 }
426- )
406+ self .input_spec = InputSpec (ndim = self .rank + 2 , axes = {channel_axis : input_dim * 2 })
427407 self .built = True
428408
429409 def call (self , inputs , ** kwargs ):
@@ -457,9 +437,7 @@ def call(self, inputs, **kwargs):
457437 "strides" : self .strides [0 ] if self .rank == 1 else self .strides ,
458438 "padding" : self .padding ,
459439 "data_format" : self .data_format ,
460- "dilation_rate" : self .dilation_rate [0 ]
461- if self .rank == 1
462- else self .dilation_rate ,
440+ "dilation_rate" : self .dilation_rate [0 ] if self .rank == 1 else self .dilation_rate ,
463441 }
464442 if self .transposed :
465443 convArgs .pop ("dilation_rate" , None )
@@ -518,12 +496,8 @@ def call(self, inputs, **kwargs):
518496 broadcast_mu_imag = K .reshape (mu_imag , broadcast_mu_shape )
519497 reshaped_f_real_centred = reshaped_f_real - broadcast_mu_real
520498 reshaped_f_imag_centred = reshaped_f_imag - broadcast_mu_imag
521- Vrr = (
522- K .mean (reshaped_f_real_centred ** 2 , axis = reduction_axes ) + self .epsilon
523- )
524- Vii = (
525- K .mean (reshaped_f_imag_centred ** 2 , axis = reduction_axes ) + self .epsilon
526- )
499+ Vrr = K .mean (reshaped_f_real_centred ** 2 , axis = reduction_axes ) + self .epsilon
500+ Vii = K .mean (reshaped_f_imag_centred ** 2 , axis = reduction_axes ) + self .epsilon
527501 Vri = (
528502 K .mean (
529503 reshaped_f_real_centred * reshaped_f_imag_centred ,
@@ -558,9 +532,7 @@ def call(self, inputs, **kwargs):
558532
559533 cat_kernels_4_real = K .concatenate ([f_real , - f_imag ], axis = - 2 )
560534 cat_kernels_4_imag = K .concatenate ([f_imag , f_real ], axis = - 2 )
561- cat_kernels_4_complex = K .concatenate (
562- [cat_kernels_4_real , cat_kernels_4_imag ], axis = - 1
563- )
535+ cat_kernels_4_complex = K .concatenate ([cat_kernels_4_real , cat_kernels_4_imag ], axis = - 1 )
564536 if False and self .transposed :
565537 cat_kernels_4_complex ._keras_shape = self .kernel_size + (
566538 2 * self .filters ,
@@ -632,9 +604,7 @@ def get_config(self):
632604 "gamma_off_initializer" : sanitizedInitSer (self .gamma_off_initializer ),
633605 "kernel_regularizer" : regularizers .serialize (self .kernel_regularizer ),
634606 "bias_regularizer" : regularizers .serialize (self .bias_regularizer ),
635- "gamma_diag_regularizer" : regularizers .serialize (
636- self .gamma_diag_regularizer
637- ),
607+ "gamma_diag_regularizer" : regularizers .serialize (self .gamma_diag_regularizer ),
638608 "gamma_off_regularizer" : regularizers .serialize (self .gamma_off_regularizer ),
639609 "activity_regularizer" : regularizers .serialize (self .activity_regularizer ),
640610 "kernel_constraint" : constraints .serialize (self .kernel_constraint ),
@@ -1113,10 +1083,7 @@ def build(self, input_shape):
11131083 else :
11141084 channel_axis = - 1
11151085 if input_shape [channel_axis ] is None :
1116- raise ValueError (
1117- "The channel dimension of the inputs "
1118- "should be defined. Found `None`."
1119- )
1086+ raise ValueError ("The channel dimension of the inputs " "should be defined. Found `None`." )
11201087 input_dim = input_shape [channel_axis ]
11211088 gamma_shape = (input_dim * self .filters ,)
11221089 self .gamma = self .add_weight (
@@ -1134,32 +1101,22 @@ def call(self, inputs):
11341101 else :
11351102 channel_axis = - 1
11361103 if input_shape [channel_axis ] is None :
1137- raise ValueError (
1138- "The channel dimension of the inputs "
1139- "should be defined. Found `None`."
1140- )
1104+ raise ValueError ("The channel dimension of the inputs " "should be defined. Found `None`." )
11411105 input_dim = input_shape [channel_axis ]
11421106 ker_shape = self .kernel_size + (input_dim , self .filters )
11431107 nb_kernels = ker_shape [- 2 ] * ker_shape [- 1 ]
11441108 kernel_shape_4_norm = (np .prod (self .kernel_size ), nb_kernels )
11451109 reshaped_kernel = K .reshape (self .kernel , kernel_shape_4_norm )
1146- normalized_weight = K .l2_normalize (
1147- reshaped_kernel , axis = 0 , epsilon = self .epsilon
1148- )
1149- normalized_weight = (
1150- K .reshape (self .gamma , (1 , ker_shape [- 2 ] * ker_shape [- 1 ]))
1151- * normalized_weight
1152- )
1110+ normalized_weight = K .l2_normalize (reshaped_kernel , axis = 0 , epsilon = self .epsilon )
1111+ normalized_weight = K .reshape (self .gamma , (1 , ker_shape [- 2 ] * ker_shape [- 1 ])) * normalized_weight
11531112 shaped_kernel = K .reshape (normalized_weight , ker_shape )
11541113 shaped_kernel ._keras_shape = ker_shape
11551114
11561115 convArgs = {
11571116 "strides" : self .strides [0 ] if self .rank == 1 else self .strides ,
11581117 "padding" : self .padding ,
11591118 "data_format" : self .data_format ,
1160- "dilation_rate" : self .dilation_rate [0 ]
1161- if self .rank == 1
1162- else self .dilation_rate ,
1119+ "dilation_rate" : self .dilation_rate [0 ] if self .rank == 1 else self .dilation_rate ,
11631120 }
11641121 convFunc = {1 : K .conv1d , 2 : K .conv2d , 3 : K .conv3d }[self .rank ]
11651122 output = convFunc (inputs , shaped_kernel , ** convArgs )
0 commit comments