@@ -65,42 +65,20 @@ def __init__(self, fn):
     def forward(self, x, *args, **kwargs):
         return self.fn(x, *args, **kwargs) + x
 
-def Upsample(dim, dim_out = None):
-    return nn.Sequential(
-        nn.Upsample(scale_factor = 2, mode = 'nearest'),
-        nn.Conv2d(dim, default(dim_out, dim), 3, padding = 1)
-    )
-
-def Downsample(dim, dim_out = None):
-    return nn.Conv2d(dim, default(dim_out, dim), 4, 2, 1)
+# use layernorm without bias, more stable
 
 class LayerNorm(nn.Module):
     def __init__(self, dim):
         super().__init__()
-        self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
-
-    def forward(self, x):
-        eps = 1e-5 if x.dtype == torch.float32 else 1e-3
-        var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
-        mean = torch.mean(x, dim = 1, keepdim = True)
-        return (x - mean) * var.clamp(min = eps).rsqrt() * self.g
-
-class PreNorm(nn.Module):
-    def __init__(self, dim, fn):
-        super().__init__()
-        self.fn = fn
-        self.norm = LayerNorm(dim)
+        self.gamma = nn.Parameter(torch.ones(dim))
+        self.register_buffer("beta", torch.zeros(dim))
 
     def forward(self, x):
-        x = self.norm(x)
-        return self.fn(x)
+        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)
 
 # positional embeds
 
 class LearnedSinusoidalPosEmb(nn.Module):
-    """ following @crowsonkb 's lead with learned sinusoidal pos emb """
-    """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """
-
     def __init__(self, dim):
         super().__init__()
         assert (dim % 2) == 0
@@ -114,54 +92,13 @@ def forward(self, x):
         fouriered = torch.cat((x, fouriered), dim = -1)
         return fouriered
 
-# building block modules
-
-class Block(nn.Module):
-    def __init__(self, dim, dim_out, groups = 8):
-        super().__init__()
-        self.proj = nn.Conv2d(dim, dim_out, 3, padding = 1)
-        self.norm = nn.GroupNorm(groups, dim_out)
-        self.act = nn.SiLU()
-
-    def forward(self, x, scale_shift = None):
-        x = self.proj(x)
-        x = self.norm(x)
-
-        if exists(scale_shift):
-            scale, shift = scale_shift
-            x = x * (scale + 1) + shift
-
-        x = self.act(x)
-        return x
-
-class ResnetBlock(nn.Module):
-    def __init__(self, dim, dim_out, *, time_emb_dim = None, groups = 8):
-        super().__init__()
-        self.mlp = nn.Sequential(
-            nn.SiLU(),
-            nn.Linear(time_emb_dim, dim_out * 2)
-        ) if exists(time_emb_dim) else None
-
-        self.block1 = Block(dim, dim_out, groups = groups)
-        self.block2 = Block(dim_out, dim_out, groups = groups)
-        self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
-
-    def forward(self, x, time_emb = None):
-
-        scale_shift = None
-        if exists(self.mlp) and exists(time_emb):
-            time_emb = self.mlp(time_emb)
-            time_emb = rearrange(time_emb, 'b c -> b c 1 1')
-            scale_shift = time_emb.chunk(2, dim = 1)
-
-        h = self.block1(x, scale_shift = scale_shift)
-
-        h = self.block2(h)
-
-        return h + self.res_conv(x)
-
 class LinearAttention(nn.Module):
-    def __init__(self, dim, heads = 4, dim_head = 32):
+    def __init__(
+        self,
+        dim,
+        heads = 4,
+        dim_head = 32
+    ):
         super().__init__()
         self.scale = dim_head ** -0.5
         self.heads = heads
@@ -182,7 +119,6 @@ def forward(self, x):
         k = k.softmax(dim = -1)
 
         q = q * self.scale
-        v = v / (h * w)
 
         context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
 
@@ -191,7 +127,12 @@ def forward(self, x):
         return self.to_out(out)
 
 class Attention(nn.Module):
-    def __init__(self, dim, heads = 4, dim_head = 32):
+    def __init__(
+        self,
+        dim,
+        heads = 4,
+        dim_head = 32
+    ):
         super().__init__()
         self.scale = dim_head ** -0.5
         self.heads = heads
@@ -212,6 +153,27 @@ def forward(self, x):
         out = rearrange(out, 'b h (x y) d -> b (h d) x y', x = h, y = w)
         return self.to_out(out)
 
+class FiLM(nn.Module):
+    def __init__(
+        self,
+        dim,
+        hidden_dim
+    ):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Linear(dim, hidden_dim * 4),
+            nn.SiLU(),
+            nn.Linear(hidden_dim * 4, hidden_dim * 2)
+        )
+
+        nn.init.zeros_(self.net[-1].weight)
+        nn.init.zeros_(self.net[-1].bias)
+
+    def forward(self, conditions, hiddens):
+        scale, shift = self.net(conditions).chunk(2, dim = -1)
+        scale, shift = map(lambda t: rearrange(t, 'b d -> b 1 d'), (scale, shift))
+        return hiddens * (scale + 1) + shift
+
 # model
 
 class RIN(nn.Module):
@@ -352,7 +314,7 @@ def beta_linear_log_snr(t):
     return -log(expm1(1e-4 + 10 * (t ** 2)))
 
 def alpha_cosine_log_snr(t, s = 0.008):
-    return -log((torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** -2) - 1, eps = 1e-5) # not sure if this accounts for beta being clipped to 0.999 in discrete version
+    return -log((torch.cos((t + s) / (1 + s) * math.pi * 0.5) ** -2) - 1, eps = 1e-5)
 
 def gamma_sigmoid_log_snr(t, start = -3, end = 3, tau = 1, clamp_min = 1e-5):
     v_start = torch.tensor(start / tau).sigmoid()
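Illustration (not part of the commit): a minimal, self-contained sketch of how the bias-free LayerNorm and the zero-initialized FiLM module introduced in this diff might be exercised. The dimension values and tensor shapes below are hypothetical, chosen only for the example, and it assumes torch and einops are installed.

import torch
from torch import nn
import torch.nn.functional as F
from einops import rearrange

class LayerNorm(nn.Module):
    # layernorm without a learned bias; beta is registered as a zero buffer
    def __init__(self, dim):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(dim))
        self.register_buffer("beta", torch.zeros(dim))

    def forward(self, x):
        return F.layer_norm(x, x.shape[-1:], self.gamma, self.beta)

class FiLM(nn.Module):
    # feature-wise scale and shift computed from a conditioning vector;
    # the final linear layer is zero-initialized, so the module starts as an identity map
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim * 4),
            nn.SiLU(),
            nn.Linear(hidden_dim * 4, hidden_dim * 2)
        )
        nn.init.zeros_(self.net[-1].weight)
        nn.init.zeros_(self.net[-1].bias)

    def forward(self, conditions, hiddens):
        scale, shift = self.net(conditions).chunk(2, dim = -1)
        scale, shift = map(lambda t: rearrange(t, 'b d -> b 1 d'), (scale, shift))
        return hiddens * (scale + 1) + shift

# hypothetical shapes: a batch of 2 sequences of 64 tokens with dim 256,
# conditioned on a 128-dim embedding
norm = LayerNorm(256)
film = FiLM(dim = 128, hidden_dim = 256)

tokens = torch.randn(2, 64, 256)
cond = torch.randn(2, 128)

out = film(cond, norm(tokens))
print(out.shape)                          # torch.Size([2, 64, 256])
assert torch.allclose(out, norm(tokens))  # identity at init, due to the zero init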