Skip to content

Commit 3c527cd

Browse files
committed
Add self.built = False to work around issue #20871 (corrected in PR #20880), and update parameter handling in InvertibleDownSampling / InvertibleUpSampling
1 parent d7e5dd2 commit 3c527cd

File tree

4 files changed

+58
-50
lines changed

4 files changed

+58
-50
lines changed

deel/lip/layers/pooling.py

Lines changed: 56 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
"""
2323

2424
import numpy as np
25+
from typing import Tuple
2526
import keras
2627
import keras.ops as K
2728
from keras.saving import register_keras_serializable
@@ -91,6 +92,7 @@ def __init__(
9192
data_format=data_format,
9293
**kwargs,
9394
)
95+
self.built = False
9496
self.set_klip_factor(k_coef_lip)
9597
self._kwargs = kwargs
9698

@@ -181,6 +183,7 @@ def __init__(
181183
data_format=data_format,
182184
**kwargs,
183185
)
186+
self.built = False
184187
self.set_klip_factor(k_coef_lip)
185188
self.eps_grad_sqrt = eps_grad_sqrt
186189
self._kwargs = kwargs
@@ -246,6 +249,7 @@ def __init__(self, data_format=None, k_coef_lip=1.0, eps_grad_sqrt=1e-6, **kwarg
246249
super(ScaledGlobalL2NormPooling2D, self).__init__(
247250
data_format=data_format, **kwargs
248251
)
252+
self.built = False
249253
self.set_klip_factor(k_coef_lip)
250254
self.eps_grad_sqrt = eps_grad_sqrt
251255
self._kwargs = kwargs
@@ -308,6 +312,7 @@ def __init__(self, data_format=None, k_coef_lip=1.0, **kwargs):
308312
super(ScaledGlobalAveragePooling2D, self).__init__(
309313
data_format=data_format, **kwargs
310314
)
315+
self.built = False
311316
self.set_klip_factor(k_coef_lip)
312317
self._kwargs = kwargs
313318

@@ -363,32 +368,44 @@ def __init__(
363368
**kwargs: params passed to the Layers constructor
364369
"""
365370
super(InvertibleDownSampling, self).__init__(name=name, dtype=dtype, **kwargs)
366-
self.pool_size = pool_size
367371
self.data_format = data_format
368372

369-
def call(self, inputs):
370-
if self.data_format == "channels_last":
371-
return K.concatenate(
372-
[
373-
inputs[
374-
:, i :: self.pool_size[0], j :: self.pool_size[1], :
375-
] # for now we handle only channels last
376-
for i in range(self.pool_size[0])
377-
for j in range(self.pool_size[1])
378-
],
379-
axis=-1,
380-
)
373+
ndims = 2
374+
ks: Tuple[int, ...]
375+
if isinstance(pool_size, int):
376+
ks = (pool_size,) * ndims
381377
else:
382-
return K.concatenate(
383-
[
384-
inputs[
385-
:, :, i :: self.pool_size[0], j :: self.pool_size[1]
386-
] # for now we handle only channels last
387-
for i in range(self.pool_size[0])
388-
for j in range(self.pool_size[1])
389-
],
390-
axis=1,
378+
ks = tuple(pool_size)
379+
380+
if len(ks) != ndims:
381+
raise ValueError(
382+
f"Expected {ndims}-dimensional pool_size, but "
383+
f"got {len(ks)}-dimensional instead"
384+
)
385+
self.pool_size = ks
386+
387+
def call(self, inputs):
388+
if self.data_format == "channels_first":
389+
# convert to channels_first
390+
inputs = K.transpose(inputs, [0, 2, 3, 1])
391+
# from shape (bs, w*pw, h*ph, c) to (bs, w, h, c, pw, ph)
392+
input_shape = K.shape(inputs)
393+
w, h, c_in = input_shape[1], input_shape[2], input_shape[3]
394+
pw, ph = self.pool_size
395+
wo = w // pw
396+
ho = h // ph
397+
inputs = K.reshape(inputs, (-1, wo, pw, h, c_in))
398+
inputs = K.reshape(inputs, (-1, wo, pw, ho, ph, c_in))
399+
inputs = K.transpose(
400+
inputs, [0, 1, 3, 5, 2, 4]
401+
) # (bs, wo, pw, ho, ph, c) -> (bs, wo, ho, c, pw, ph)
402+
inputs = K.reshape(inputs, (-1, wo, ho, c_in * pw * ph))
403+
404+
if self.data_format == "channels_first":
405+
inputs = K.transpose(
406+
inputs, [0, 3, 1, 2] # (bs, w, h, c*pw*ph) -> (bs, c*pw*ph, w, h) ->
391407
)
408+
return inputs
392409

393410
def get_config(self):
394411
config = {
@@ -427,9 +444,22 @@ def __init__(
427444
**kwargs: params passed to the Layers constructor
428445
"""
429446
super(InvertibleUpSampling, self).__init__(name=name, dtype=dtype, **kwargs)
430-
self.pool_size = pool_size
431447
self.data_format = data_format
432448

449+
ndims = 2
450+
ks: Tuple[int, ...]
451+
if isinstance(pool_size, int):
452+
ks = (pool_size,) * ndims
453+
else:
454+
ks = tuple(pool_size)
455+
456+
if len(ks) != ndims:
457+
raise ValueError(
458+
f"Expected {ndims}-dimensional pool_size, but "
459+
f"got {len(ks)}-dimensional instead"
460+
)
461+
self.pool_size = ks
462+
433463
def call(self, inputs):
434464
if self.data_format == "channels_first":
435465
# convert to channels_first
@@ -439,12 +469,12 @@ def call(self, inputs):
439469
w, h, c_in = input_shape[1], input_shape[2], input_shape[3]
440470
pw, ph = self.pool_size
441471
c = c_in // (pw * ph)
442-
inputs = K.reshape(inputs, (-1, w, h, pw, ph, c))
472+
inputs = K.reshape(inputs, (-1, w, h, c, pw, ph))
443473
inputs = K.transpose(
444474
K.reshape(
445475
K.transpose(
446-
inputs, [0, 5, 2, 4, 1, 3]
447-
), # (bs, w, h, pw, ph, c) -> (bs, c, w, pw, h, ph)
476+
inputs, [0, 3, 2, 5, 1, 4]
477+
), # (bs, w, h, c, pw, ph) -> (bs, c, w, pw, h, ph)
448478
(-1, c, w, pw, h * ph),
449479
), # (bs, c, w, pw, h, ph) -> (bs, c, w, pw, h*ph) merge last axes
450480
[

deel/lip/utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,6 @@ def evaluate_lip_const(model: keras.Model, x, eps=1e-4, seed=None):
9393
return keras.ops.max(spectral_norms).numpy()
9494

9595

96-
9796
def _padding_circular(x, circular_paddings):
9897
"""Add circular padding to a 4-D tensor. Only channels_last is supported."""
9998
if circular_paddings is None:

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ deps =
1919

2020

2121
commands =
22-
python -m unittest
22+
pytest tests
2323

2424
[testenv:py{39,310,311}-lint]
2525
skip_install = true

tests/utils_framework.py

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,8 @@
22
import os
33
import warnings
44
from functools import partial
5-
import pytest
65
import numpy as np
76

8-
import keras
97
import keras.utils as K
108
import tensorflow as tf
119

@@ -24,9 +22,6 @@
2422
# from tensorflow.keras.optimizers import SGD, Adam
2523
from keras.optimizers import SGD, Adam
2624

27-
# from tensorflow.keras.losses import CategoricalCrossentropy as CategoricalCrossentropy
28-
from keras.losses import CategoricalCrossentropy as CategoricalCrossentropy
29-
3025
# from tensorflow.keras.layers import Layer as tLayer
3126
from keras.layers import Layer as tLayer
3227

@@ -36,16 +31,6 @@
3631
# from tensorflow.keras.layers import Flatten
3732
from keras.layers import Flatten
3833

39-
40-
from keras.utils import set_random_seed as set_seed
41-
42-
# from tensorflow.keras import backend as K
43-
# from tensorflow.keras.metrics import mse as tmse
44-
from keras.metrics import MeanSquaredError as tmse
45-
46-
# from tensorflow.keras.losses import MeanSquaredError as MeanSquaredError
47-
from keras.losses import MeanSquaredError as MeanSquaredError
48-
4934
# from tensorflow import int32 as type_int32
5035
from tensorflow import int32 as type_int32
5136

@@ -79,9 +64,6 @@
7964
# from tensorflow.keras.layers import Concatenate as tConcatenate
8065
from keras.layers import Concatenate as tConcatenate
8166

82-
import sys
83-
84-
sys.path.append(".")
8567
from deel.lip.activations import GroupSort as GroupSort
8668
from deel.lip.activations import GroupSort2 as GroupSort2
8769
from deel.lip.activations import Householder as HouseHolder
@@ -118,7 +100,6 @@
118100
from deel.lip.losses import MulticlassHinge as HingeMulticlassLoss
119101
from deel.lip.losses import MulticlassHKR as HKRMulticlassLoss
120102

121-
####from deel.lip.losses import MulticlassSoftHKR as SoftHKRMulticlassLoss
122103
from deel.lip.losses import MultiMargin as MultiMarginLoss
123104
from deel.lip.losses import TauCategoricalCrossentropy as TauCategoricalCrossentropyLoss
124105
from deel.lip.losses import (
@@ -176,7 +157,7 @@
176157
"CondenseCallback",
177158
"MonitorCallback",
178159
"Sequential",
179-
"ScaledGlobalL2NormPool2d",
160+
"ScaledAdaptativeL2NormPool2d",
180161
"evaluate_lip_const",
181162
"DEFAULT_NITER_SPECTRAL_INIT",
182163
"Loss",
@@ -209,9 +190,7 @@
209190
]
210191

211192
FIT = "fit"
212-
# if tf.__version__.startswith("2.0") else "fit"
213193
EVALUATE = "evaluate"
214-
# if tf.__version__.startswith("2.0") else "evaluate"
215194

216195
MODEL_PATH = "model"
217196
EXTENSION = ".keras"

0 commit comments

Comments (0)