@@ -40,6 +40,7 @@ def __init__(
4040 diffusion_sigma_data : int ,
4141 diffusion_dynamic_threshold : float ,
4242 out_channels : Optional [int ] = None ,
43+ context_channels : Optional [Sequence [int ]] = None ,
4344 use_autoencoder : bool = False ,
4445 autoencoder : Optional [AutoEncoder1d ] = None ,
4546 autoencoder_scale : float = 1.0 ,
@@ -72,6 +73,7 @@ def __init__(
7273 use_skip_scale = use_skip_scale ,
7374 use_attention_bottleneck = use_attention_bottleneck ,
7475 out_channels = out_channels ,
76+ context_channels = context_channels ,
7577 )
7678
7779 self .diffusion = Diffusion (
@@ -81,21 +83,26 @@ def __init__(
8183 dynamic_threshold = diffusion_dynamic_threshold ,
8284 )
8385
84- def forward (self , x : Tensor ) -> Tensor :
86+ def forward (self , x : Tensor , ** kwargs ) -> Tensor :
8587 if self .use_autoencoder :
8688 x = self .autoencoder_scale * self .autoencoder .encode (x ) # type: ignore
87- return self .diffusion (x )
89+ return self .diffusion (x , ** kwargs )
8890
8991 def sample (
90- self , noise : Tensor , num_steps : int , sigma_schedule : Schedule , sampler : Sampler
92+ self ,
93+ noise : Tensor ,
94+ num_steps : int ,
95+ sigma_schedule : Schedule ,
96+ sampler : Sampler ,
97+ ** kwargs
9198 ) -> Tensor :
9299 diffusion_sampler = DiffusionSampler (
93100 diffusion = self .diffusion ,
94101 sampler = sampler ,
95102 sigma_schedule = sigma_schedule ,
96103 num_steps = num_steps ,
97104 )
98- x = diffusion_sampler (noise )
105+ x = diffusion_sampler (noise , ** kwargs )
99106
100107 if self .use_autoencoder :
101108 x = (1.0 / self .autoencoder_scale ) * self .autoencoder .decode (x )
0 commit comments