2222"""
2323
2424import numpy as np
25+ from typing import Tuple
2526import keras
2627import keras .ops as K
2728from keras .saving import register_keras_serializable
@@ -91,6 +92,7 @@ def __init__(
9192 data_format = data_format ,
9293 ** kwargs ,
9394 )
95+ self .built = False
9496 self .set_klip_factor (k_coef_lip )
9597 self ._kwargs = kwargs
9698
@@ -181,6 +183,7 @@ def __init__(
181183 data_format = data_format ,
182184 ** kwargs ,
183185 )
186+ self .built = False
184187 self .set_klip_factor (k_coef_lip )
185188 self .eps_grad_sqrt = eps_grad_sqrt
186189 self ._kwargs = kwargs
@@ -246,6 +249,7 @@ def __init__(self, data_format=None, k_coef_lip=1.0, eps_grad_sqrt=1e-6, **kwarg
246249 super (ScaledGlobalL2NormPooling2D , self ).__init__ (
247250 data_format = data_format , ** kwargs
248251 )
252+ self .built = False
249253 self .set_klip_factor (k_coef_lip )
250254 self .eps_grad_sqrt = eps_grad_sqrt
251255 self ._kwargs = kwargs
@@ -308,6 +312,7 @@ def __init__(self, data_format=None, k_coef_lip=1.0, **kwargs):
308312 super (ScaledGlobalAveragePooling2D , self ).__init__ (
309313 data_format = data_format , ** kwargs
310314 )
315+ self .built = False
311316 self .set_klip_factor (k_coef_lip )
312317 self ._kwargs = kwargs
313318
@@ -363,32 +368,44 @@ def __init__(
363368 **kwargs: params passed to the Layers constructor
364369 """
365370 super (InvertibleDownSampling , self ).__init__ (name = name , dtype = dtype , ** kwargs )
366- self .pool_size = pool_size
367371 self .data_format = data_format
368372
369- def call (self , inputs ):
370- if self .data_format == "channels_last" :
371- return K .concatenate (
372- [
373- inputs [
374- :, i :: self .pool_size [0 ], j :: self .pool_size [1 ], :
375- ] # for now we handle only channels last
376- for i in range (self .pool_size [0 ])
377- for j in range (self .pool_size [1 ])
378- ],
379- axis = - 1 ,
380- )
373+ ndims = 2
374+ ks : Tuple [int , ...]
375+ if isinstance (pool_size , int ):
376+ ks = (pool_size ,) * ndims
381377 else :
382- return K .concatenate (
383- [
384- inputs [
385- :, :, i :: self .pool_size [0 ], j :: self .pool_size [1 ]
386- ] # for now we handle only channels last
387- for i in range (self .pool_size [0 ])
388- for j in range (self .pool_size [1 ])
389- ],
390- axis = 1 ,
378+ ks = tuple (pool_size )
379+
380+ if len (ks ) != ndims :
381+ raise ValueError (
382+ f"Expected { ndims } -dimensional pool_size, but "
383+ f"got { len (ks )} -dimensional instead"
384+ )
385+ self .pool_size = ks
386+
def call(self, inputs):
    """Fold every ``pool_size`` spatial patch into the channel axis.

    The spatial axes are reduced by ``self.pool_size`` and the channel
    axis is multiplied by ``pw * ph``. With a channels-last view of the
    data, the mapping is::

        out[b, i, j, c * pw * ph + p * ph + q] == in[b, i * pw + p, j * ph + q, c]

    i.e. the new channel axis is ordered ``(c, pw, ph)``.

    Args:
        inputs: 4D tensor. It is read as ``(bs, c, w, h)`` when
            ``self.data_format == "channels_first"`` and as
            ``(bs, w, h, c)`` for any other ``data_format`` value.

    Returns:
        Tensor of shape ``(bs, w // pw, h // ph, c * pw * ph)``, or
        ``(bs, c * pw * ph, w // pw, h // ph)`` for ``channels_first``.

    Note:
        Output sizes are computed with floor division and no divisibility
        check is performed here; inputs whose spatial sizes are not
        multiples of ``pool_size`` are not handled by this method.
    """
    channels_first = self.data_format == "channels_first"
    if channels_first:
        # Work in channels_last internally: (bs, c, w, h) -> (bs, w, h, c).
        inputs = K.transpose(inputs, [0, 2, 3, 1])

    input_shape = K.shape(inputs)
    w, h, c_in = input_shape[1], input_shape[2], input_shape[3]
    pw, ph = self.pool_size
    wo = w // pw
    ho = h // ph

    # Split both spatial axes in a single reshape:
    # (bs, wo*pw, ho*ph, c) -> (bs, wo, pw, ho, ph, c).
    inputs = K.reshape(inputs, (-1, wo, pw, ho, ph, c_in))
    # Move the intra-patch offsets behind the channel axis:
    # (bs, wo, pw, ho, ph, c) -> (bs, wo, ho, c, pw, ph).
    inputs = K.transpose(inputs, [0, 1, 3, 5, 2, 4])
    # Merge (c, pw, ph) into the new channel axis.
    inputs = K.reshape(inputs, (-1, wo, ho, c_in * pw * ph))

    if channels_first:
        # Restore the caller's layout:
        # (bs, wo, ho, c*pw*ph) -> (bs, c*pw*ph, wo, ho).
        inputs = K.transpose(inputs, [0, 3, 1, 2])
    return inputs
392409
393410 def get_config (self ):
394411 config = {
@@ -427,9 +444,22 @@ def __init__(
427444 **kwargs: params passed to the Layers constructor
428445 """
429446 super (InvertibleUpSampling , self ).__init__ (name = name , dtype = dtype , ** kwargs )
430- self .pool_size = pool_size
431447 self .data_format = data_format
432448
449+ ndims = 2
450+ ks : Tuple [int , ...]
451+ if isinstance (pool_size , int ):
452+ ks = (pool_size ,) * ndims
453+ else :
454+ ks = tuple (pool_size )
455+
456+ if len (ks ) != ndims :
457+ raise ValueError (
458+ f"Expected { ndims } -dimensional pool_size, but "
459+ f"got { len (ks )} -dimensional instead"
460+ )
461+ self .pool_size = ks
462+
433463 def call (self , inputs ):
434464 if self .data_format == "channels_first" :
435465 # convert to channels_first
@@ -439,12 +469,12 @@ def call(self, inputs):
439469 w , h , c_in = input_shape [1 ], input_shape [2 ], input_shape [3 ]
440470 pw , ph = self .pool_size
441471 c = c_in // (pw * ph )
442- inputs = K .reshape (inputs , (- 1 , w , h , pw , ph , c ))
472+ inputs = K .reshape (inputs , (- 1 , w , h , c , pw , ph ))
443473 inputs = K .transpose (
444474 K .reshape (
445475 K .transpose (
446- inputs , [0 , 5 , 2 , 4 , 1 , 3 ]
447- ), # (bs, w, h, pw, ph, c ) -> (bs, c, w, pw, h, ph)
476+ inputs , [0 , 3 , 2 , 5 , 1 , 4 ]
477+ ), # (bs, w, h, c, pw, ph ) -> (bs, c, w, pw, h, ph)
448478 (- 1 , c , w , pw , h * ph ),
449479 ), # (bs, c, w, pw, h, ph) -> (bs, c, w, pw, h*ph) merge last axes
450480 [
0 commit comments