@@ -142,27 +142,40 @@ def __init__(self, in_channels: int, out_channels: int, add_temporal_upsample: b
142142 self .add_temporal_upsample = add_temporal_upsample
143143 self .repeats = factor * out_channels // in_channels
144144
@staticmethod
def _dcae_upsample_rearrange(tensor, r1=1, r2=2, r3=2):
    """
    Unpack channel-packed factors into the frame/height/width axes.

    Maps (b, r1*r2*r3*c, f, h, w) -> (b, c, f*r1, h*r2, w*r3); equivalent to
    the einops pattern "b (r1 r2 r3 c) f h w -> b c (f r1) (h r2) (w r3)".

    Args:
        tensor: Input tensor of shape (b, r1*r2*r3*c, f, h, w)
        r1: temporal upsampling factor
        r2: height upsampling factor
        r3: width upsampling factor

    Returns:
        Tensor of shape (b, c, f*r1, h*r2, w*r3).
    """
    batch, packed_channels, frames, height, width = tensor.shape
    channels = packed_channels // (r1 * r2 * r3)

    # Split the channel axis into (r1, r2, r3, c); the other axes are untouched.
    unpacked = tensor.view(batch, r1, r2, r3, channels, frames, height, width)
    # Place every factor right after the axis it expands: (b, c, f, r1, h, r2, w, r3).
    interleaved = unpacked.permute(0, 4, 5, 1, 6, 2, 7, 3)
    # Merging each (axis, factor) pair needs a copy, hence reshape rather than view.
    return interleaved.reshape(batch, channels, frames * r1, height * r2, width * r3)
163+
def forward(self, x: torch.Tensor):
    """
    Run the conv branch and a parameter-free shortcut, then sum them.

    Both branches are unpacked from channels into the height/width axes
    (factor 2 each) via `_dcae_upsample_rearrange`.

    Args:
        x: Input tensor; the rearrange helper expects (b, c, f, h, w) layout.

    Returns:
        Sum of the rearranged conv output and the rearranged shortcut.
    """
    temporal_factor = 2 if self.add_temporal_upsample else 1
    main = self.conv(x)

    if not self.add_temporal_upsample:
        main = self._dcae_upsample_rearrange(main, r1=temporal_factor, r2=2, r3=2)
        # Shortcut: duplicate input channels first, then unpack them the same way.
        residual = x.repeat_interleave(repeats=self.repeats, dim=1)
        residual = self._dcae_upsample_rearrange(residual, r1=temporal_factor, r2=2, r3=2)
        return main + residual

    # Temporal-upsample configuration: only height/width are unpacked here,
    # and just the first half of the resulting channels is kept.
    main = self._dcae_upsample_rearrange(main, r1=1, r2=2, r3=2)
    main = main[:, : main.shape[1] // 2]

    # Shortcut: unpack the raw input, then duplicate channels to match.
    residual = self._dcae_upsample_rearrange(x, r1=1, r2=2, r3=2)
    residual = residual.repeat_interleave(repeats=self.repeats // 2, dim=1)
    return main + residual
167180
168181
@@ -177,20 +190,40 @@ def __init__(self, in_channels: int, out_channels: int, add_temporal_downsample:
177190 self .add_temporal_downsample = add_temporal_downsample
178191 self .group_size = factor * in_channels // out_channels
179192
193+
@staticmethod
def _dcae_downsample_rearrange(tensor, r1=1, r2=2, r3=2):
    """
    Pack frame/height/width factors into the channel axis.

    Maps (b, c, f*r1, h*r2, w*r3) -> (b, r1*r2*r3*c, f, h, w); equivalent to
    the einops pattern "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w".
    This is the inverse of the channel-to-space unpacking used for upsampling.

    Args:
        tensor: Input tensor of shape (b, c, f*r1, h*r2, w*r3)
        r1: temporal downsampling factor
        r2: height downsampling factor
        r3: width downsampling factor

    Returns:
        Tensor of shape (b, r1*r2*r3*c, f, h, w).
    """
    # NOTE: this is a staticmethod, so it must not take `self`; callers invoke
    # it as `self._dcae_downsample_rearrange(tensor, ...)`.
    b, c, packed_f, packed_h, packed_w = tensor.shape
    f, h, w = packed_f // r1, packed_h // r2, packed_w // r3

    # In "(f r1)" the factor is the fast-varying (inner) index, so each packed
    # axis is split as (size, factor), not (factor, size).
    tensor = tensor.view(b, c, f, r1, h, r2, w, r3)
    # Move the factors in front of the channels: (b, r1, r2, r3, c, f, h, w).
    tensor = tensor.permute(0, 3, 5, 7, 1, 2, 4, 6)
    return tensor.reshape(b, r1 * r2 * r3 * c, f, h, w)
207+
208+
180209 def forward (self , x : torch .Tensor ):
181210 r1 = 2 if self .add_temporal_downsample else 1
182211 h = self .conv (x )
183212 if self .add_temporal_downsample :
184- h = rearrange (h , "b c f (h r2) (w r3) -> b (r2 r3 c) f h w" , r2 = 2 , r3 = 2 )
213+ # h = rearrange(h, "b c f (h r2) (w r3) -> b (r2 r3 c) f h w", r2=2, r3=2)
214+ h = self ._dcae_downsample_rearrange (h , r1 = 1 , r2 = 2 , r3 = 2 )
185215 h = torch .cat ([h , h ], dim = 1 )
186216
187217 # shortcut computation
188- shortcut = rearrange (x , "b c f (h r2) (w r3) -> b (r2 r3 c) f h w" , r2 = 2 , r3 = 2 )
218+ # shortcut = rearrange(x, "b c f (h r2) (w r3) -> b (r2 r3 c) f h w", r2=2, r3=2)
219+ shortcut = self ._dcae_downsample_rearrange (x , r1 = 1 , r2 = 2 , r3 = 2 )
189220 B , C , T , H , W = shortcut .shape
190221 shortcut = shortcut .view (B , h .shape [1 ], self .group_size // 2 , T , H , W ).mean (dim = 2 )
191222 else :
192- h = rearrange (h , "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w" , r1 = r1 , r2 = 2 , r3 = 2 )
193- shortcut = rearrange (x , "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w" , r1 = r1 , r2 = 2 , r3 = 2 )
223+ # h = rearrange(h, "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w", r1=r1, r2=2, r3=2)
224+ h = self ._dcae_downsample_rearrange (h , r1 = r1 , r2 = 2 , r3 = 2 )
225+ # shortcut = rearrange(x, "b c (f r1) (h r2) (w r3) -> b (r1 r2 r3 c) f h w", r1=r1, r2=2, r3=2)
226+ shortcut = self ._dcae_downsample_rearrange (x , r1 = r1 , r2 = 2 , r3 = 2 )
194227 B , C , T , H , W = shortcut .shape
195228 shortcut = shortcut .view (B , h .shape [1 ], self .group_size , T , H , W ).mean (dim = 2 )
196229
0 commit comments