@@ -134,6 +134,30 @@ def inner(
 
     return inner
 
+# handle channel last
+
+def handle_maybe_channel_last(fn):
+
+    @wraps(fn)
+    def inner(
+        self,
+        x,
+        *args,
+        **kwargs
+    ):
+
+        if self.channel_last:
+            x = rearrange(x, 'b ... c -> b c ...')
+
+        out = fn(self, x, *args, **kwargs)
+
+        if self.channel_last:
+            out = rearrange(out, 'b c ... -> b ... c')
+
+        return out
+
+    return inner
+
 # helpers
 
 def Sequential(*modules):
@@ -164,14 +188,17 @@ class TemporalDownsample(Module):
     def __init__(
         self,
         dim,
+        channel_last = False,
         time_dim = None
     ):
         super().__init__()
         self.time_dim = time_dim
+        self.channel_last = channel_last
 
         self.conv = nn.Conv1d(dim, dim, kernel_size = 3, stride = 2, padding = 1)
         init_bilinear_kernel_1d_(self.conv)
 
+    @handle_maybe_channel_last
     @image_or_video_to_time
     def forward(
         self,
@@ -185,14 +212,17 @@ class TemporalUpsample(Module):
     def __init__(
         self,
         dim,
+        channel_last = False,
         time_dim = None
     ):
         super().__init__()
         self.time_dim = time_dim
+        self.channel_last = channel_last
 
         self.conv = nn.ConvTranspose1d(dim, dim, kernel_size = 3, stride = 2, padding = 1, output_padding = 1)
         init_bilinear_kernel_1d_(self.conv)
 
+    @handle_maybe_channel_last
     @image_or_video_to_time
     def forward(
         self,
@@ -210,13 +240,15 @@ def __init__(
         conv2d_kernel_size = 3,
         conv1d_kernel_size = 3,
         groups = 8,
+        channel_last = False,
         time_dim = None
     ):
         super().__init__()
         assert is_odd(conv2d_kernel_size)
         assert is_odd(conv1d_kernel_size)
 
         self.time_dim = time_dim
+        self.channel_last = channel_last
 
         self.spatial_conv = nn.Sequential(
             nn.Conv2d(dim, dim, conv2d_kernel_size, padding = conv2d_kernel_size // 2),
@@ -235,6 +267,7 @@ def __init__(
         nn.init.zeros_(self.proj_out.weight)
         nn.init.zeros_(self.proj_out.bias)
 
+    @handle_maybe_channel_last
     def forward(
         self,
         x,
@@ -277,11 +310,13 @@ def __init__(
         prenorm = True,
         residual_attn = True,
         time_dim = None,
+        channel_last = False,
         **attn_kwargs
     ):
         super().__init__()
 
         self.time_dim = time_dim
+        self.channel_last = channel_last
 
         self.temporal_attns = ModuleList([])
 
@@ -304,6 +339,7 @@ def __init__(
         nn.init.zeros_(self.proj_out.weight)
         nn.init.zeros_(self.proj_out.bias)
 
+    @handle_maybe_channel_last
     def forward(
         self,
         x,
@@ -312,6 +348,9 @@ def forward(
         is_video = x.ndim == 5
         assert is_video ^ (exists(batch_size) or exists(self.time_dim)), 'either a tensor of shape (batch, channels, time, height, width) is passed in, or (batch * time, channels, height, width) along with `batch_size`'
 
+        if self.channel_last:
+            x = rearrange(x, 'b ... c -> b c ...')
+
         if is_video:
             batch_size = x.shape[0]
             x = rearrange(x, 'b c t h w -> b h w t c')
@@ -339,6 +378,9 @@ def forward(
         else:
             x = rearrange(x, 'b h w t c -> (b t) c h w')
 
+        if self.channel_last:
+            x = rearrange(x, 'b c ... -> b ... c')
+
         return x
 
 # post module hook wrapper
@@ -375,7 +417,9 @@ def __init__(
         upsample_module_names: List[str] = [],
         channels: int = 3,
         conv_inflation_kwargs: dict = dict(),
-        attn_inflation_kwargs: dict = dict()
+        attn_inflation_kwargs: dict = dict(),
+        downsample_kwargs: dict = dict(),
+        upsample_kwargs: dict = dict(),
     ):
         super().__init__()
 
@@ -421,8 +465,8 @@ def __init__(
 
         self.convs = ModuleList([ConvolutionInflationBlock(dim = shape[1], **conv_inflation_kwargs) for shape in conv_shapes])
         self.attns = ModuleList([AttentionInflationBlock(dim = shape[1], **attn_inflation_kwargs) for shape in attn_shapes])
-        self.downsamples = ModuleList([TemporalDownsample(dim = shape[1]) for shape in downsample_shapes])
-        self.upsamples = ModuleList([TemporalUpsample(dim = shape[1]) for shape in upsample_shapes])
+        self.downsamples = ModuleList([TemporalDownsample(dim = shape[1], **downsample_kwargs) for shape in downsample_shapes])
+        self.upsamples = ModuleList([TemporalUpsample(dim = shape[1], **upsample_kwargs) for shape in upsample_shapes])
 
         # insert all the temporal modules with hooks
 
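Usage sketch (not part of the commit): a minimal example of how the new `channel_last` flag on the temporal modules could be exercised, assuming the classes are importable from the package root and using illustrative dims and tensor shapes. The modules still run channel-first internally; the decorator only converts a channel-last input to channel-first before the wrapped forward and converts the output back.

# minimal sketch, assuming `lumiere_pytorch` exposes these classes at the package root
import torch
from lumiere_pytorch import TemporalDownsample, TemporalUpsample

down = TemporalDownsample(dim = 512, channel_last = True)
up = TemporalUpsample(dim = 512, channel_last = True)

# channel-last video features: (batch, time, height, width, channels) - illustrative shape
video = torch.randn(1, 8, 16, 16, 512)

downsampled = down(video)    # time halved   -> (1, 4, 16, 16, 512)
upsampled = up(downsampled)  # time restored -> (1, 8, 16, 16, 512)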