Skip to content

Commit 24a4cfe

Browse files
committed
Black formatting preparing for github action run
1 parent ca4495c commit 24a4cfe

File tree

12 files changed

+619
-711
lines changed

12 files changed

+619
-711
lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,4 +42,6 @@ venv
4242

4343
.vscode
4444

45-
dist/*
45+
dist/*
46+
47+
__pycache__/*

complexnn/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#
77
# What this module includes by default:
88
from . import bn, conv, dense, init, norm, pool
9+
910
# from . import fft
1011

1112
from .bn import ComplexBatchNormalization as ComplexBN
@@ -17,6 +18,7 @@
1718
WeightNorm_Conv,
1819
)
1920
from .dense import ComplexDense
21+
2022
# from .fft import (fft, ifft, fft2, ifft2, FFT, IFFT, FFT2, IFFT2)
2123
from .init import (
2224
ComplexIndependentFilters,

complexnn/bn.py

Lines changed: 150 additions & 142 deletions
Large diffs are not rendered by default.

complexnn/conv.py

Lines changed: 18 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,5 @@
11
#!/usr/bin/env python
22
# -*- coding: utf-8 -*-
3-
"""conv.py"""
4-
# pylint: disable=protected-access, consider-using-enumerate, too-many-lines
5-
6-
#
7-
# Authors: Chiheb Trabelsi
83

94
import tensorflow as tf
105
from tensorflow.keras import backend as K
@@ -24,7 +19,7 @@
2419

2520
def conv1d_transpose(
2621
inputs,
27-
filter, # pylint: disable=redefined-builtin
22+
filter,
2823
kernel_size=None,
2924
filters=None,
3025
strides=(1,),
@@ -76,7 +71,7 @@ def conv1d_transpose(
7671

7772
def conv2d_transpose(
7873
inputs,
79-
filter, # pylint: disable=redefined-builtin
74+
filter,
8075
kernel_size=None,
8176
filters=None,
8277
strides=(1, 1),
@@ -121,9 +116,7 @@ def conv2d_transpose(
121116
output_shape = (batch_size, out_height, out_width, filters)
122117

123118
filter = K.permute_dimensions(filter, (0, 1, 3, 2))
124-
return K.conv2d_transpose(
125-
inputs, filter, output_shape, strides, padding=padding, data_format=data_format
126-
)
119+
return K.conv2d_transpose(inputs, filter, output_shape, strides, padding=padding, data_format=data_format)
127120

128121

129122
def ifft(f):
@@ -136,9 +129,7 @@ def ifft2(f):
136129
raise NotImplementedError(str(f))
137130

138131

139-
def conv_transpose_output_length(
140-
input_length, filter_size, padding, stride, dilation=1, output_padding=None
141-
):
132+
def conv_transpose_output_length(input_length, filter_size, padding, stride, dilation=1, output_padding=None):
142133
"""Rearrange arguments for compatibility with conv_output_length."""
143134
if dilation != 1:
144135
msg = f"Dilation must be 1 for transposed convolution. "
@@ -287,14 +278,8 @@ def __init__(
287278
self.kernel_size = conv_utils.normalize_tuple(kernel_size, rank, "kernel_size")
288279
self.strides = conv_utils.normalize_tuple(strides, rank, "strides")
289280
self.padding = conv_utils.normalize_padding(padding)
290-
self.data_format = (
291-
"channels_last"
292-
if rank == 1
293-
else conv_utils.normalize_data_format(data_format)
294-
)
295-
self.dilation_rate = conv_utils.normalize_tuple(
296-
dilation_rate, rank, "dilation_rate"
297-
)
281+
self.data_format = "channels_last" if rank == 1 else conv_utils.normalize_data_format(data_format)
282+
self.dilation_rate = conv_utils.normalize_tuple(dilation_rate, rank, "dilation_rate")
298283
self.activation = activations.get(activation)
299284
self.use_bias = use_bias
300285
self.normalize_weight = normalize_weight
@@ -336,10 +321,7 @@ def build(self, input_shape):
336321
else:
337322
channel_axis = -1
338323
if input_shape[channel_axis] is None:
339-
raise ValueError(
340-
"The channel dimension of the inputs "
341-
"should be defined. Found `None`."
342-
)
324+
raise ValueError("The channel dimension of the inputs " "should be defined. Found `None`.")
343325
# Divide by 2 for real and complex input.
344326
input_dim = input_shape[channel_axis] // 2
345327
if False and self.transposed:
@@ -421,9 +403,7 @@ def build(self, input_shape):
421403
self.bias = None
422404

423405
# Set input spec.
424-
self.input_spec = InputSpec(
425-
ndim=self.rank + 2, axes={channel_axis: input_dim * 2}
426-
)
406+
self.input_spec = InputSpec(ndim=self.rank + 2, axes={channel_axis: input_dim * 2})
427407
self.built = True
428408

429409
def call(self, inputs, **kwargs):
@@ -457,9 +437,7 @@ def call(self, inputs, **kwargs):
457437
"strides": self.strides[0] if self.rank == 1 else self.strides,
458438
"padding": self.padding,
459439
"data_format": self.data_format,
460-
"dilation_rate": self.dilation_rate[0]
461-
if self.rank == 1
462-
else self.dilation_rate,
440+
"dilation_rate": self.dilation_rate[0] if self.rank == 1 else self.dilation_rate,
463441
}
464442
if self.transposed:
465443
convArgs.pop("dilation_rate", None)
@@ -518,12 +496,8 @@ def call(self, inputs, **kwargs):
518496
broadcast_mu_imag = K.reshape(mu_imag, broadcast_mu_shape)
519497
reshaped_f_real_centred = reshaped_f_real - broadcast_mu_real
520498
reshaped_f_imag_centred = reshaped_f_imag - broadcast_mu_imag
521-
Vrr = (
522-
K.mean(reshaped_f_real_centred**2, axis=reduction_axes) + self.epsilon
523-
)
524-
Vii = (
525-
K.mean(reshaped_f_imag_centred**2, axis=reduction_axes) + self.epsilon
526-
)
499+
Vrr = K.mean(reshaped_f_real_centred**2, axis=reduction_axes) + self.epsilon
500+
Vii = K.mean(reshaped_f_imag_centred**2, axis=reduction_axes) + self.epsilon
527501
Vri = (
528502
K.mean(
529503
reshaped_f_real_centred * reshaped_f_imag_centred,
@@ -558,9 +532,7 @@ def call(self, inputs, **kwargs):
558532

559533
cat_kernels_4_real = K.concatenate([f_real, -f_imag], axis=-2)
560534
cat_kernels_4_imag = K.concatenate([f_imag, f_real], axis=-2)
561-
cat_kernels_4_complex = K.concatenate(
562-
[cat_kernels_4_real, cat_kernels_4_imag], axis=-1
563-
)
535+
cat_kernels_4_complex = K.concatenate([cat_kernels_4_real, cat_kernels_4_imag], axis=-1)
564536
if False and self.transposed:
565537
cat_kernels_4_complex._keras_shape = self.kernel_size + (
566538
2 * self.filters,
@@ -632,9 +604,7 @@ def get_config(self):
632604
"gamma_off_initializer": sanitizedInitSer(self.gamma_off_initializer),
633605
"kernel_regularizer": regularizers.serialize(self.kernel_regularizer),
634606
"bias_regularizer": regularizers.serialize(self.bias_regularizer),
635-
"gamma_diag_regularizer": regularizers.serialize(
636-
self.gamma_diag_regularizer
637-
),
607+
"gamma_diag_regularizer": regularizers.serialize(self.gamma_diag_regularizer),
638608
"gamma_off_regularizer": regularizers.serialize(self.gamma_off_regularizer),
639609
"activity_regularizer": regularizers.serialize(self.activity_regularizer),
640610
"kernel_constraint": constraints.serialize(self.kernel_constraint),
@@ -1113,10 +1083,7 @@ def build(self, input_shape):
11131083
else:
11141084
channel_axis = -1
11151085
if input_shape[channel_axis] is None:
1116-
raise ValueError(
1117-
"The channel dimension of the inputs "
1118-
"should be defined. Found `None`."
1119-
)
1086+
raise ValueError("The channel dimension of the inputs " "should be defined. Found `None`.")
11201087
input_dim = input_shape[channel_axis]
11211088
gamma_shape = (input_dim * self.filters,)
11221089
self.gamma = self.add_weight(
@@ -1134,32 +1101,22 @@ def call(self, inputs):
11341101
else:
11351102
channel_axis = -1
11361103
if input_shape[channel_axis] is None:
1137-
raise ValueError(
1138-
"The channel dimension of the inputs "
1139-
"should be defined. Found `None`."
1140-
)
1104+
raise ValueError("The channel dimension of the inputs " "should be defined. Found `None`.")
11411105
input_dim = input_shape[channel_axis]
11421106
ker_shape = self.kernel_size + (input_dim, self.filters)
11431107
nb_kernels = ker_shape[-2] * ker_shape[-1]
11441108
kernel_shape_4_norm = (np.prod(self.kernel_size), nb_kernels)
11451109
reshaped_kernel = K.reshape(self.kernel, kernel_shape_4_norm)
1146-
normalized_weight = K.l2_normalize(
1147-
reshaped_kernel, axis=0, epsilon=self.epsilon
1148-
)
1149-
normalized_weight = (
1150-
K.reshape(self.gamma, (1, ker_shape[-2] * ker_shape[-1]))
1151-
* normalized_weight
1152-
)
1110+
normalized_weight = K.l2_normalize(reshaped_kernel, axis=0, epsilon=self.epsilon)
1111+
normalized_weight = K.reshape(self.gamma, (1, ker_shape[-2] * ker_shape[-1])) * normalized_weight
11531112
shaped_kernel = K.reshape(normalized_weight, ker_shape)
11541113
shaped_kernel._keras_shape = ker_shape
11551114

11561115
convArgs = {
11571116
"strides": self.strides[0] if self.rank == 1 else self.strides,
11581117
"padding": self.padding,
11591118
"data_format": self.data_format,
1160-
"dilation_rate": self.dilation_rate[0]
1161-
if self.rank == 1
1162-
else self.dilation_rate,
1119+
"dilation_rate": self.dilation_rate[0] if self.rank == 1 else self.dilation_rate,
11631120
}
11641121
convFunc = {1: K.conv1d, 2: K.conv2d, 3: K.conv3d}[self.rank]
11651122
output = convFunc(inputs, shaped_kernel, **convArgs)

0 commit comments

Comments
 (0)