from einops_exts import rearrange_many
from torch import Tensor, einsum

-from .utils import closest_power_2, default, exists, groupby
+from .utils import closest_power_2, default, exists, groupby, is_sequence

"""
Utils
@@ -909,9 +909,11 @@ def __init__(
        self.use_stft = use_stft
        self.use_stft_context = use_stft_context

+        self.context_features = context_features
        context_channels_pad_length = num_layers + 1 - len(context_channels)
        context_channels = context_channels + [0] * context_channels_pad_length
        self.context_channels = context_channels
+        self.context_embedding_features = context_embedding_features

        if use_context_channels:
            has_context = [c > 0 for c in context_channels]
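# Sketch (editor's annotation, not part of this commit): the padding above
# gives `context_channels` one entry per UNet level. Assuming num_layers=4
# and context_channels=[1, 2]:
#   context_channels_pad_length = 4 + 1 - 2 = 3
#   context_channels            = [1, 2, 0, 0, 0]
# so only the first two levels receive injected context channels.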
@@ -1140,22 +1142,21 @@ def rand_bool(shape: Any, proba: float, device: Any = None) -> Tensor:
    return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)


-class UNetConditional1d(UNet1d):
-    """
-    UNet1d with classifier-free guidance on the token embeddings
-    """
+class UNetCFG1d(UNet1d):
+
+    """UNet1d with Classifier-Free Guidance"""

    def __init__(
        self,
-        context_embedding_features: int,
        context_embedding_max_length: int,
+        context_embedding_features: int,
        **kwargs,
    ):
        super().__init__(
            context_embedding_features=context_embedding_features, **kwargs
        )
        self.fixed_embedding = FixedEmbedding(
-            context_embedding_max_length, context_embedding_features
+            max_length=context_embedding_max_length, features=context_embedding_features
        )

    def forward(  # type: ignore
@@ -1178,14 +1179,72 @@ def forward( # type: ignore
            )
            embedding = torch.where(batch_mask, fixed_embedding, embedding)

-        out = super().forward(x, time, embedding=embedding, **kwargs)
-
        if embedding_scale != 1.0:
-            # Scale conditional output using classifier-free guidance
+            # Compute both normal and fixed embedding outputs
+            out = super().forward(x, time, embedding=embedding, **kwargs)
            out_masked = super().forward(x, time, embedding=fixed_embedding, **kwargs)
-            out = out_masked + (out - out_masked) * embedding_scale
+            # Scale conditional output using classifier-free guidance
+            return out_masked + (out - out_masked) * embedding_scale
+        else:
+            return super().forward(x, time, embedding=embedding, **kwargs)
+
+
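# Sketch (editor's annotation, not part of this commit): a self-contained
# illustration of the guidance formula used above, as plain tensor math.
# All values below are made up for demonstration.
import torch

out = torch.tensor([1.0, 2.0])          # conditional output (real embedding)
out_masked = torch.tensor([0.5, 0.5])   # unconditional output (fixed embedding)
embedding_scale = 5.0
guided = out_masked + (out - out_masked) * embedding_scale
# embedding_scale=1.0 recovers `out` exactly (hence the fast path above);
# larger scales extrapolate away from the unconditional prediction.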
+class UNetNCCA1d(UNet1d):
+
+    """UNet1d with Noise Channel Conditioning Augmentation"""
+
+    def __init__(self, context_features: int, **kwargs):
+        super().__init__(context_features=context_features, **kwargs)
+        self.embedder = NumberEmbedder(features=context_features)
+
+    def forward(  # type: ignore
+        self,
+        x: Tensor,
+        time: Tensor,
+        *,
+        channels_list: Sequence[Tensor],
+        channels_augmentation: bool = False,
+        channels_scale: Union[int, Sequence[int]] = 0,
+        **kwargs,
+    ) -> Tensor:
+        b, num_items = x.shape[0], len(channels_list)
+
+        if channels_augmentation:
+            # Random noise augmentation for each item
+            channels_scale = torch.rand(num_items, b).to(x)  # type: ignore
+            for i in range(num_items):
+                item = channels_list[i]
+                scale = rearrange(channels_scale[i], "b -> b 1 1")  # type: ignore
+                channels_list[i] = torch.randn_like(item) * scale + item * (1 - scale)  # type: ignore # noqa
+        else:
+            # Expand same scale to each batch element
+            if is_sequence(channels_scale):
+                assert_message = "len(channels_scale) must match len(channels_list)"
+                assert len(channels_scale) == num_items, assert_message
+            else:
+                channels_scale = num_items * [channels_scale]  # type: ignore
+            channels_scale = torch.tensor(channels_scale).to(x)  # type: ignore
+            channels_scale = repeat(channels_scale, "n -> n b", b=b)
+
+        # Compute scale feature embedding
+        scale_embedding = self.embedder(channels_scale)
+        scale_embedding = reduce(scale_embedding, "n b d -> b d", "sum")
+
+        return super().forward(
+            x=x,
+            time=time,
+            channels_list=channels_list,
+            features=scale_embedding,
+            **kwargs,
+        )
+
+
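# Sketch (editor's annotation, not part of this commit): a self-contained
# illustration of the augmentation step above. Each conditioning tensor is
# blended with Gaussian noise by a per-example scale in [0, 1], and that same
# scale is what gets embedded as `features`, telling the network how corrupted
# its conditioning is. Shapes below are illustrative assumptions.
import torch

item = torch.randn(4, 2, 1024)           # one conditioning tensor (b, c, t)
scale = torch.rand(4).view(4, 1, 1)      # per-batch-element noise level
augmented = torch.randn_like(item) * scale + item * (1 - scale)
# scale=0 leaves the item untouched; scale=1 replaces it with pure noise.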
+class UNetAll1d(UNetCFG1d, UNetNCCA1d):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)

-        return out
+    def forward(self, *args, **kwargs):  # type: ignore
+        return UNetCFG1d.forward(self, *args, **kwargs)
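# Sketch (editor's annotation, not part of this commit): with this MRO,
# the `super().forward(...)` call inside UNetCFG1d.forward dispatches to
# UNetNCCA1d.forward before reaching UNet1d.forward, so one call applies
# both behaviors. A minimal self-contained demo of the same cooperative-super
# pattern (toy classes, not the real UNets):
class Base:
    def forward(self):
        return ["base"]

class CFG(Base):
    def forward(self):
        return ["cfg"] + super().forward()

class NCCA(Base):
    def forward(self):
        return ["ncca"] + super().forward()

class All(CFG, NCCA):
    def forward(self):
        return CFG.forward(self)

print(All().forward())  # ['cfg', 'ncca', 'base']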


class T5Embedder(nn.Module):