88
99from .components import AppendChannelsPlugin , MelSpectrogram
1010from .diffusion import ARVDiffusion , ARVSampler , VDiffusion , VSampler
11- from .utils import closest_power_2 , default , downsample , groupby , randn_like , upsample
11+ from .utils import (
12+ closest_power_2 ,
13+ default ,
14+ downsample ,
15+ exists ,
16+ groupby ,
17+ randn_like ,
18+ upsample ,
19+ )
1220
1321
1422class DiffusionModel (nn .Module ):
@@ -46,6 +54,18 @@ def __init__(self):
4654 self .downsample_factor = None
4755
4856
57+ class AdapterBase (nn .Module , ABC ):
58+ """Abstract class for DiffusionAE encoder"""
59+
60+ @abstractmethod
61+ def encode (self , x : Tensor ) -> Tensor :
62+ pass
63+
64+ @abstractmethod
65+ def decode (self , x : Tensor ) -> Tensor :
66+ pass
67+
68+
4969class DiffusionAE (DiffusionModel ):
5070 """Diffusion Auto Encoder"""
5171
@@ -55,6 +75,8 @@ def __init__(
5575 channels : Sequence [int ],
5676 encoder : EncoderBase ,
5777 inject_depth : int ,
78+ latent_factor : Optional [int ] = None ,
79+ adapter : Optional [AdapterBase ] = None ,
5880 ** kwargs ,
5981 ):
6082 context_channels = [0 ] * len (channels )
@@ -68,12 +90,19 @@ def __init__(
6890 self .in_channels = in_channels
6991 self .encoder = encoder
7092 self .inject_depth = inject_depth
93+ # Optional custom latent factor and adapter
94+ self .latent_factor = default (latent_factor , self .encoder .downsample_factor )
95+ self .adapter = adapter .requires_grad_ (False ) if exists (adapter ) else None
7196
7297 def forward ( # type: ignore
7398 self , x : Tensor , with_info : bool = False , ** kwargs
7499 ) -> Union [Tensor , Tuple [Tensor , Any ]]:
100+ # Encode input to latent channels
75101 latent , info = self .encode (x , with_info = True )
76102 channels = [None ] * self .inject_depth + [latent ]
103+ # Adapt input to diffusion if adapter provided
104+ x = self .adapter .encode (x ) if exists (self .adapter ) else x
105+ # Compute diffusion loss
77106 loss = super ().forward (x , channels = channels , ** kwargs )
78107 return (loss , info ) if with_info else loss
79108
@@ -85,18 +114,20 @@ def decode(
85114 self , latent : Tensor , generator : Optional [Generator ] = None , ** kwargs
86115 ) -> Tensor :
87116 b = latent .shape [0 ]
88- length = closest_power_2 (latent .shape [2 ] * self .encoder . downsample_factor )
117+ noise_length = closest_power_2 (latent .shape [2 ] * self .latent_factor )
89118 # Compute noise by inferring shape from latent length
90119 noise = torch .randn (
91- (b , self .in_channels , length ),
120+ (b , self .in_channels , noise_length ),
92121 device = latent .device ,
93122 dtype = latent .dtype ,
94123 generator = generator ,
95124 )
96125 # Compute context from latent
97126 channels = [None ] * self .inject_depth + [latent ] # type: ignore
98127 # Decode by sampling while conditioning on latent channels
99- return super ().sample (noise , channels = channels , ** kwargs )
128+ out = super ().sample (noise , channels = channels , ** kwargs )
129+ # Decode output with adapter if provided
130+ return self .adapter .decode (out ) if exists (self .adapter ) else out
100131
101132
102133class DiffusionUpsampler (DiffusionModel ):
0 commit comments