@@ -104,10 +104,10 @@ def forward(self, x):
104
104
"""x --> [K x Conv2D] --> PatchMerge
105
105
106
106
Args:
107
- x : (B, T, H, W, C)
107
+ x: (B, T, H, W, C)
108
108
109
109
Returns:
110
- out : (B, T, H_new, W_new, C_out)
110
+ out: (B, T, H_new, W_new, C_out)
111
111
"""
112
112
113
113
B , T , H , W , C = x .shape
@@ -178,10 +178,10 @@ def forward(self, x):
178
178
"""x --> Upsample --> [K x Conv2D]
179
179
180
180
Args:
181
- x : (B, T, H, W, C)
181
+ x: (B, T, H, W, C)
182
182
183
183
Returns:
184
- out : (B, T, H_new, W_new, C)
184
+ out: (B, T, H_new, W_new, C)
185
185
"""
186
186
187
187
x = self .upsample (x )
@@ -286,10 +286,10 @@ def forward(self, x):
286
286
"""x --> [K x Conv2D] --> PatchMerge --> ... --> [K x Conv2D] --> PatchMerge
287
287
288
288
Args:
289
- x : (B, T, H, W, C)
289
+ x: (B, T, H, W, C)
290
290
291
291
Returns:
292
- out : (B, T, H_new, W_new, C_out)
292
+ out: (B, T, H_new, W_new, C_out)
293
293
"""
294
294
295
295
for i , (conv_block , patch_merge ) in enumerate (
@@ -400,10 +400,10 @@ def forward(self, x):
400
400
"""x --> Upsample --> [K x Conv2D] --> ... --> Upsample --> [K x Conv2D]
401
401
402
402
Args:
403
- x : Shape (B, T, H, W, C)
403
+ x: Shape (B, T, H, W, C)
404
404
405
405
Returns:
406
- out : Shape (B, T, H_new, W_new, C)
406
+ out: Shape (B, T, H_new, W_new, C)
407
407
"""
408
408
for i , (conv_block , upsample ) in enumerate (
409
409
zip (self .conv_block_list , self .upsample_list )
@@ -915,10 +915,10 @@ def get_initial_z(self, final_mem, T_out):
915
915
def forward (self , x , verbose = False ):
916
916
"""
917
917
Args:
918
- x : Shape (B, T, H, W, C)
919
- verbos : if True, print intermediate shapes
918
+ x: Shape (B, T, H, W, C)
919
+ verbose : if True, print intermediate shapes
920
920
Returns:
921
- out : The output Shape (B, T_out, H, W, C_out)
921
+ out: The output Shape (B, T_out, H, W, C_out)
922
922
"""
923
923
924
924
x = self .concat_to_tensor (x , self .input_keys )
0 commit comments