@@ -159,9 +159,10 @@ def PixelUnshuffle(self, x):
159159
160160
161161class BatchNorm1d (nn .BatchNorm1d , INNAbstract .INNModule ):
162- def __init__ (self , dim ):
162+ def __init__ (self , dim , requires_grad = False ):
163163 INNAbstract .INNModule .__init__ (self )
164164 nn .BatchNorm1d .__init__ (self , num_features = dim , affine = False )
165+ self .requires_grad = requires_grad
165166
166167 def forward (self , x , log_p = 0 , log_det_J = 0 ):
167168
@@ -171,7 +172,9 @@ def forward(self, x, log_p=0, log_det_J=0):
171172 var = self .running_var # [dim]
172173 else :
173174 # if in training
174- var = torch .var (x , dim = 0 , unbiased = False ).detach () # [dim]
175+ var = torch .var (x , dim = 0 , unbiased = False )#.detach() # [dim]
176+ if not self .requires_grad :
177+ var = var .detach ()
175178
176179 x = super (BatchNorm1d , self ).forward (x )
177180
@@ -211,11 +214,11 @@ def inverse(self, y, **args):
211214
212215class RealNVP (INNAbstract .INNModule ):
213216
214- def __init__ (self , dim = None , f_log_s = None , f_t = None , k = 4 , mask = None , clip = 1 ):
217+ def __init__ (self , dim = None , f_log_s = None , f_t = None , k = 4 , mask = None , clip = 1 , activation_fn = None ):
215218 super (RealNVP , self ).__init__ ()
216219 if (f_log_s is None ) and (f_t is None ):
217- log_s = utilities .default_net (dim , k )#self.default_net(dim, k)
218- t = utilities .default_net (dim , k )#self.default_net(dim, k)
220+ log_s = utilities .default_net (dim , k , activation_fn )#self.default_net(dim, k)
221+ t = utilities .default_net (dim , k , activation_fn )#self.default_net(dim, k)
219222 self .net = utilities .combined_real_nvp (dim , log_s , t , mask , clip )
220223 else :
221224 self .net = utilities .combined_real_nvp (dim , f_log_s , f_t , mask , clip )
@@ -234,11 +237,11 @@ def inverse(self, y, **args):
234237
235238class NICE (INNAbstract .INNModule ):
236239
237- def __init__ (self , dim = None , m = None , mask = None , k = 4 ):
240+ def __init__ (self , dim = None , m = None , mask = None , k = 4 , activation_fn = None ):
238241 super (NICE , self ).__init__ ()
239242
240243 if m is None :
241- m_ = utilities .default_net (dim , k )
244+ m_ = utilities .default_net (dim , k , activation_fn )
242245 self .net = utilities .NICE (dim , m = m_ , mask = mask )
243246 else :
244247 self .net = utilities .NICE (dim , m = m , mask = mask )
@@ -268,18 +271,18 @@ class Nonlinear(INNAbstract.INNModule):
268271 '''
269272 Nonlinear invertible block
270273 '''
271- def __init__ (self , dim , method = 'NICE ' , m = None , mask = None , k = 4 , ** args ):
274+ def __init__ (self , dim , method = 'RealNVP ' , m = None , mask = None , k = 4 , activation_fn = None , ** args ):
272275 super (Nonlinear , self ).__init__ ()
273276
274277 self .method = method
275278 if method == 'NICE' :
276- self .block = NICE (dim , m = m , mask = mask , k = k )
279+ self .block = NICE (dim , m = m , mask = mask , k = k , activation_fn = activation_fn )
277280 if method == 'RealNVP' :
278281 clip = _default_dict ('clip' , args , 1 )
279282 f_log_s = _default_dict ('f_log_s' , args , None )
280283 f_t = _default_dict ('f_t' , args , None )
281284
282- self .block = RealNVP (dim = dim , f_log_s = f_log_s , f_t = f_t , k = k , mask = mask , clip = clip )
285+ self .block = RealNVP (dim = dim , f_log_s = f_log_s , f_t = f_t , k = k , mask = mask , clip = clip , activation_fn = activation_fn )
283286 if method == 'iResNet' :
284287 g = _default_dict ('g' , args , None )
285288 beta = _default_dict ('beta' , args , 0.8 )
@@ -293,4 +296,70 @@ def forward(self, x, log_p0=0, log_det_J=0):
293296 return self .block (x , log_p0 , log_det_J )
294297
295298 def inverse (self , y , ** args ):
296- return self .block .inverse (y , ** args )
299+ return self .block .inverse (y , ** args )
300+
class ResizeFeatures(INNAbstract.INNModule):
    '''
    Resize for n-d input, include linear or multi-channel inputs.

    Keeps the first `feature_out` features and factors the remaining
    `feature_in - feature_out` features out into the latent distribution
    `dist` ('normal' for utilities.NormalDistribution, or any
    INNAbstract.Distribution instance).
    '''
    def __init__(self, feature_in, feature_out, dist='normal'):
        super(ResizeFeatures, self).__init__()
        self.feature_in = feature_in
        self.feature_out = feature_out

        if dist == 'normal':
            self.dist = utilities.NormalDistribution()
        elif isinstance(dist, INNAbstract.Distribution):
            self.dist = dist
        else:
            # Fail fast: previously an unrecognized `dist` left self.dist
            # unset, deferring the failure to an AttributeError in
            # forward()/inverse().
            raise ValueError(
                f"dist must be 'normal' or an INNAbstract.Distribution, got {dist!r}")

    def resize(self, x, feature_in, feature_out):
        '''
        Split x into the kept features y and the factored-out features z.

        x has two kinds of shapes:
            1. [feature_in]
            2. [batch_size, feature_in, *]
        Returns (y, z) with y holding the first `feature_out` features.
        '''
        if len(x.shape) == 1:
            # [feature_in]
            if x.shape[0] != self.feature_in:
                raise Exception(f'Expect to get {self.feature_in} features, but got {x.shape[0]}.')
            return x[:feature_out], x[feature_out:]

        if len(x.shape) >= 2:
            # [batch_size, feature_in, *]
            if x.shape[1] != self.feature_in:
                raise Exception(f'Expect to get {self.feature_in} features, but got {x.shape[1]}.')
            return x[:, :feature_out], x[:, feature_out:]

        # 0-d input previously fell through both branches and raised an
        # opaque UnboundLocalError; make the failure explicit.
        raise ValueError(f'Unsupported input of shape {tuple(x.shape)}.')

    def forward(self, x, log_p0=0, log_det_J=0):
        x, z = self.resize(x, self.feature_in, self.feature_out)
        # NOTE(review): self.compute_p is presumably set by
        # INNAbstract.INNModule — confirm against the base class.
        if self.compute_p:
            p = self.dist.logp(z)
            return x, log_p0 + p, log_det_J
        else:
            return x

    def inverse(self, y, **args):
        '''
        y has two kinds of shapes (note: feature_out, not feature_in —
        the original docstring said feature_in but the code checks
        feature_out):
            1. [feature_out]
            2. [batch_size, feature_out, *]
        The factored-out features are re-sampled from self.dist and
        concatenated back on, restoring `feature_in` features.
        '''
        if len(y.shape) == 1:
            # [feature_out]
            if y.shape[0] != self.feature_out:
                raise Exception(f'Expect to get {self.feature_out} features, but got {y.shape[0]}.')
            z = self.dist.sample(self.feature_in - self.feature_out).to(y.device)
            y = torch.cat([y, z])
        elif len(y.shape) >= 2:
            # [batch_size, feature_out, *]
            if y.shape[1] != self.feature_out:
                raise Exception(f'Expect to get {self.feature_out} features, but got {y.shape[1]}.')
            shape = list(y.shape)
            shape[1] = self.feature_in - self.feature_out
            z = self.dist.sample(shape).to(y.device)
            y = torch.cat([y, z], dim=1)
        else:
            raise ValueError(f'Unsupported input of shape {tuple(y.shape)}.')

        return y
0 commit comments