@@ -49,6 +49,11 @@ def compact_values(d: dict):
def l2norm(t, dim = -1, eps = 1e-12):
    """Scale `t` to unit L2 norm along `dim`.

    Equivalent to F.normalize: divides by max(||t||, eps), so an
    all-zero slice maps to zeros instead of producing NaNs.
    """
    denom = t.norm(dim = dim, keepdim = True).clamp(min = eps)
    return t / denom
5151
def interpolate_1d(x, length, mode = 'bilinear'):
    """Resample the last (time) dimension of `x` to `length` samples.

    F.interpolate has no 1d 'bilinear' mode, so a dummy trailing spatial
    dimension of size 1 is appended, 2d interpolation is applied over
    (length, 1), and the dummy dimension is removed again.

    Args:
        x: tensor shaped (batch, channels, time) — assumed 3-d.
        length: target size of the time dimension.
        mode: interpolation mode passed to F.interpolate.
    Returns:
        tensor shaped (batch, channels, length).
    """
    # unsqueeze/squeeze replace the previous einops rearrange round trip
    x = F.interpolate(x.unsqueeze(-1), (length, 1), mode = mode)
    return x.squeeze(-1)
56+
5257# mp activations
5358# section 2.5
5459
@@ -85,6 +90,65 @@ def forward(self, x):
8590 weight = l2norm (self .weight , eps = self .eps ) / sqrt (self .fan_in )
8691 return F .linear (x , weight )
8792
93+ # forced weight normed conv2d and linear
94+ # algorithm 1 in paper
95+
# forced weight normed conv2d and linear
# algorithm 1 in paper

class Conv2d(Module):
    """2d convolution with forced weight normalization (Algorithm 1).

    During training, the stored weight is renormalized in-place (under
    no_grad) so each output channel's full fan-in vector has unit norm;
    the forward pass then applies the same differentiable normalization
    scaled by 1 / sqrt(fan_in). Stride is 1 with 'same' padding.
    """

    def __init__(
        self,
        dim_in,
        dim_out,
        kernel_size,
        eps = 1e-4
    ):
        super().__init__()
        weight = torch.randn(dim_out, dim_in, kernel_size, kernel_size)
        self.weight = nn.Parameter(weight)

        self.eps = eps
        self.fan_in = dim_in * kernel_size ** 2

    def _unit_normed_weight(self):
        # normalize each output channel over its entire (in, kh, kw) fan-in,
        # matching the flattened 'o *' normalization used at train time
        flat = self.weight.flatten(1)
        return F.normalize(flat, dim = -1, eps = self.eps).view_as(self.weight)

    def forward(self, x):
        if self.training:
            with torch.no_grad():
                self.weight.copy_(self._unit_normed_weight())

        # fix: previously l2norm was applied to the raw 4-d weight, which
        # normalized only the last kernel dimension — inconsistent with the
        # training-time flattened normalization above and with the paper
        weight = self._unit_normed_weight() / sqrt(self.fan_in)
        return F.conv2d(x, weight, padding = 'same')
121+
122+ class Conv1d (Module ):
123+ def __init__ (
124+ self ,
125+ dim_in ,
126+ dim_out ,
127+ kernel_size ,
128+ eps = 1e-4 ,
129+ init_dirac = False
130+ ):
131+ super ().__init__ ()
132+ weight = torch .randn (dim_out , dim_in , kernel_size )
133+ self .weight = nn .Parameter (weight )
134+
135+ if init_dirac :
136+ nn .init .dirac_ (self .weight )
137+
138+ self .eps = eps
139+ self .fan_in = dim_in * kernel_size
140+
141+ def forward (self , x ):
142+ if self .training :
143+ with torch .no_grad ():
144+ weight , ps = pack_one (self .weight , 'o *' )
145+ normed_weight = l2norm (weight , eps = self .eps )
146+ normed_weight = unpack_one (normed_weight , ps , 'o *' )
147+ self .weight .copy_ (normed_weight )
148+
149+ weight = l2norm (self .weight , eps = self .eps ) / sqrt (self .fan_in )
150+ return F .conv1d (x , weight , padding = 'same' )
151+
88152# pixelnorm
89153# equation (30)
90154
@@ -183,18 +247,18 @@ def __init__(
183247 super ().__init__ ()
184248 self .time_dim = time_dim
185249 self .channel_last = channel_last
186-
187- self .conv = nn .Conv1d (dim , dim , kernel_size = 3 , stride = 2 , padding = 1 )
188- init_bilinear_kernel_1d_ (self .conv )
250+ self .conv = Conv1d (dim , dim , 3 , init_dirac = True )
189251
190252 @handle_maybe_channel_last
191253 @image_or_video_to_time
192254 def forward (
193255 self ,
194256 x
195257 ):
196- assert x .shape [- 1 ] > 1 , 'time dimension must be greater than 1 to be compressed'
258+ t = x .shape [- 1 ]
259+ assert t > 1 , 'time dimension must be greater than 1 to be compressed'
197260
261+ x = interpolate_1d (x , t // 2 )
198262 return self .conv (x )
199263
200264class MPTemporalUpsample (Module ):
@@ -207,16 +271,16 @@ def __init__(
207271 super ().__init__ ()
208272 self .time_dim = time_dim
209273 self .channel_last = channel_last
210-
211- self .conv = nn .ConvTranspose1d (dim , dim , kernel_size = 3 , stride = 2 , padding = 1 , output_padding = 1 )
212- init_bilinear_kernel_1d_ (self .conv )
274+ self .conv = Conv1d (dim , dim , 3 , init_dirac = True )
213275
214276 @handle_maybe_channel_last
215277 @image_or_video_to_time
216278 def forward (
217279 self ,
218280 x
219281 ):
282+ t = x .shape [- 1 ]
283+ x = interpolate_1d (x , t * 2 )
220284 return self .conv (x )
221285
222286# main modules
@@ -233,26 +297,23 @@ def __init__(
233297 mp_add_t = 0.3
234298 ):
235299 super ().__init__ ()
236- assert is_odd (conv2d_kernel_size )
237- assert is_odd (conv1d_kernel_size )
238-
239300 self .time_dim = time_dim
240301 self .channel_last = channel_last
241302
242303 self .spatial_conv = nn .Sequential (
243- nn . Conv2d (dim , dim , conv2d_kernel_size , padding = conv2d_kernel_size // 2 ),
304+ Conv2d (dim , dim , conv2d_kernel_size , 3 ),
244305 MPSiLU ()
245306 )
246307
247308 self .temporal_conv = nn .Sequential (
248- nn . Conv1d (dim , dim , conv1d_kernel_size , padding = conv1d_kernel_size // 2 ),
309+ Conv1d (dim , dim , conv1d_kernel_size , 3 ),
249310 MPSiLU ()
250311 )
251312
252- self .proj_out = nn .Conv1d ( dim , dim , 1 )
253-
254- nn . init . zeros_ ( self . proj_out . weight )
255- nn . init . zeros_ ( self . proj_out . bias )
313+ self .proj_out = nn .Sequential (
314+ Conv1d ( dim , dim , 1 ),
315+ Gain ( )
316+ )
256317
257318 self .residual_mp_add = MPAdd (t = mp_add_t )
258319
0 commit comments