@@ -58,16 +58,6 @@ def convert_image_to(img_type, image):
5858 return image .convert (img_type )
5959 return image
6060
61- # small helper modules
62-
63- class Residual (nn .Module ):
64- def __init__ (self , fn ):
65- super ().__init__ ()
66- self .fn = fn
67-
68- def forward (self , x , * args , ** kwargs ):
69- return self .fn (x , * args , ** kwargs ) + x
70-
7161# use layernorm without bias, more stable
7262
7363class LayerNorm (nn .Module ):
@@ -259,27 +249,6 @@ def forward(self, x, time = None):
259249
260250 return self .net (x )
261251
262- class FiLM (nn .Module ):
263- def __init__ (
264- self ,
265- dim ,
266- hidden_dim
267- ):
268- super ().__init__ ()
269- self .net = nn .Sequential (
270- nn .Linear (dim , hidden_dim * 4 ),
271- nn .SiLU (),
272- nn .Linear (hidden_dim * 4 , hidden_dim * 2 )
273- )
274-
275- nn .init .zeros_ (self .net [- 1 ].weight )
276- nn .init .zeros_ (self .net [- 1 ].bias )
277-
278- def forward (self , conditions , hiddens ):
279- scale , shift = self .net (conditions ).chunk (2 , dim = - 1 )
280- scale , shift = map (lambda t : rearrange (t , 'b d -> b 1 d' ), (scale , shift ))
281- return hiddens * (scale + 1 ) + shift
282-
283252# model
284253
285254class RIN (nn .Module ):
0 commit comments