File tree Expand file tree Collapse file tree 2 files changed +11
-4
lines changed Expand file tree Collapse file tree 2 files changed +11
-4
lines changed Original file line number Diff line number Diff line change @@ -191,10 +191,17 @@ def __init__(
191191 num_groups = resnet_groups ,
192192 )
193193
194- def forward (self , x : Tensor , ** kwargs ) -> Tensor :
195- latent = self .encode (x )
194+ def forward ( # type: ignore
195+ self , x : Tensor , with_info : bool = False , ** kwargs
196+ ) -> Union [Tensor , Tuple [Tensor , Any ]]:
197+ if with_info :
198+ latent , info = self .encode (x , with_info = True )
199+ else :
200+ latent = self .encode (x )
201+
196202 context = self .to_context (latent )
197- return self .diffusion (x , context = [context ], ** kwargs )
203+ loss = self .diffusion (x , context = [context ], ** kwargs )
204+ return (loss , info ) if with_info else loss
198205
199206 def encode (
200207 self , x : Tensor , with_info : bool = False
Original file line number Diff line number Diff line change 33setup (
44 name = "audio-diffusion-pytorch" ,
55 packages = find_packages (exclude = []),
6- version = "0.0.28 " ,
6+ version = "0.0.29 " ,
77 license = "MIT" ,
88 description = "Audio Diffusion - PyTorch" ,
99 long_description_content_type = "text/markdown" ,
You can’t perform that action at this time.
0 commit comments