@@ -227,6 +227,7 @@ class Encoder3d(nn.Module):
227227 def __init__ (self ,
228228 dim = 128 ,
229229 z_dim = 4 ,
230+ input_channels = 3 ,
230231 dim_mult = [1 , 2 , 4 , 4 ],
231232 num_res_blocks = 2 ,
232233 attn_scales = [],
@@ -245,7 +246,7 @@ def __init__(self,
245246 scale = 1.0
246247
247248 # init block
248- self .conv1 = CausalConv3d (3 , dims [0 ], 3 , padding = 1 )
249+ self .conv1 = CausalConv3d (input_channels , dims [0 ], 3 , padding = 1 )
249250
250251 # downsample blocks
251252 downsamples = []
@@ -331,6 +332,7 @@ class Decoder3d(nn.Module):
331332 def __init__ (self ,
332333 dim = 128 ,
333334 z_dim = 4 ,
335+ output_channels = 3 ,
334336 dim_mult = [1 , 2 , 4 , 4 ],
335337 num_res_blocks = 2 ,
336338 attn_scales = [],
@@ -378,7 +380,7 @@ def __init__(self,
378380 # output blocks
379381 self .head = nn .Sequential (
380382 RMS_norm (out_dim , images = False ), nn .SiLU (),
381- CausalConv3d (out_dim , 3 , 3 , padding = 1 ))
383+ CausalConv3d (out_dim , output_channels , 3 , padding = 1 ))
382384
383385 def forward (self , x , feat_cache = None , feat_idx = [0 ]):
384386 ## conv1
@@ -449,6 +451,7 @@ def __init__(self,
449451 num_res_blocks = 2 ,
450452 attn_scales = [],
451453 temperal_downsample = [True , True , False ],
454+ image_channels = 3 ,
452455 dropout = 0.0 ):
453456 super ().__init__ ()
454457 self .dim = dim
@@ -460,11 +463,11 @@ def __init__(self,
460463 self .temperal_upsample = temperal_downsample [::- 1 ]
461464
462465 # modules
463- self .encoder = Encoder3d (dim , z_dim * 2 , dim_mult , num_res_blocks ,
466+ self .encoder = Encoder3d (dim , z_dim * 2 , image_channels , dim_mult , num_res_blocks ,
464467 attn_scales , self .temperal_downsample , dropout )
465468 self .conv1 = CausalConv3d (z_dim * 2 , z_dim * 2 , 1 )
466469 self .conv2 = CausalConv3d (z_dim , z_dim , 1 )
467- self .decoder = Decoder3d (dim , z_dim , dim_mult , num_res_blocks ,
470+ self .decoder = Decoder3d (dim , z_dim , image_channels , dim_mult , num_res_blocks ,
468471 attn_scales , self .temperal_upsample , dropout )
469472
470473 def encode (self , x ):
0 commit comments