@@ -51,43 +51,38 @@ def __init__(
5151 else :
5252 self .extra_modeling = False
5353
def average_upsample_text_by_mask(self, text, text_mask):
    """Evenly upsample the valid (mask-selected) text positions to fill the
    full sequence length.

    Each valid position is repeated ``text_len // valid_len`` times, and the
    trailing ``text_len % valid_len`` valid positions are repeated once more,
    so exactly ``text_len`` frames are produced.

    Args:
        text: tensor of shape ``[1, text_len, text_dim]`` (batch must be 1).
        text_mask: bool tensor of shape ``[1, text_len]``; True marks valid
            positions to keep and spread out.

    Returns:
        Tensor of shape ``[1, text_len, text_dim]``; all zeros when the mask
        selects no positions.

    Raises:
        ValueError: if the batch dimension is not 1.
    """
    batch, text_len, text_dim = text.shape
    # Explicit check instead of `assert`, which is silently stripped under
    # `python -O`.
    if batch != 1:
        raise ValueError(f"average_upsample_text_by_mask expects batch == 1, got {batch}")

    valid_mask = text_mask[0]
    audio_len = text_len  # upsample target length equals the padded text length
    valid_len = int(valid_mask.sum().item())

    if valid_len == 0:
        # Nothing valid to upsample; keep shape/dtype with an all-zero result.
        return torch.zeros_like(text)

    valid_data = text[0, torch.where(valid_mask)[0], :]  # [valid_len, text_dim]

    # Distribute audio_len frames over valid_len tokens: every token gets
    # base_repeat frames, and the last `remainder` tokens get one extra
    # (same distribution as the original index-building loop).
    base_repeat = audio_len // valid_len
    remainder = audio_len % valid_len
    repeats = torch.full((valid_len,), base_repeat, dtype=torch.long, device=text.device)
    if remainder:
        repeats[valid_len - remainder:] += 1

    # Vectorized replacement for the original per-element Python loop; the
    # repeat counts sum to exactly audio_len, so no truncation is needed.
    indices = torch.repeat_interleave(
        torch.arange(valid_len, device=text.device), repeats
    )  # [audio_len]

    return valid_data[indices].unsqueeze(0)  # [1, audio_len, text_dim]
8984
90- def forward (self , text : int ["b nt" ], seq_len , drop_text = False , audio_mask : bool [ "b n" ] | None = None ):
85+ def forward (self , text : int ["b nt" ], seq_len , drop_text = False ):
9186 text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
9287 text = text [:, :seq_len ] # curtail if character tokens are more than the mel spec tokens
9388 text = F .pad (text , (0 , seq_len - text .shape [1 ]), value = 0 ) # (opt.) if not self.average_upsampling:
@@ -114,7 +109,7 @@ def forward(self, text: int["b nt"], seq_len, drop_text=False, audio_mask: bool[
114109 text = self .text_blocks (text )
115110
116111 if self .average_upsampling :
117- text = self .average_upsample_text_by_mask (text , ~ text_mask , audio_mask )
112+ text = self .average_upsample_text_by_mask (text , ~ text_mask )
118113
119114 return text
120115
@@ -247,17 +242,16 @@ def get_input_embed(
247242 ):
248243 if self .text_uncond is None or self .text_cond is None or not cache :
249244 if audio_mask is None :
250- text_embed = self .text_embed (text , x .shape [1 ], drop_text = drop_text , audio_mask = audio_mask )
245+ text_embed = self .text_embed (text , x .shape [1 ], drop_text = drop_text )
251246 else :
252247 batch = x .shape [0 ]
253- seq_lens = audio_mask .sum (dim = 1 )
248+ seq_lens = audio_mask .sum (dim = 1 ) # Calculate the actual sequence length for each sample
254249 text_embed_list = []
255250 for i in range (batch ):
256251 text_embed_i = self .text_embed (
257252 text [i ].unsqueeze (0 ),
258- seq_lens [i ].item (),
253+ seq_len = seq_lens [i ].item (),
259254 drop_text = drop_text ,
260- audio_mask = audio_mask ,
261255 )
262256 text_embed_list .append (text_embed_i [0 ])
263257 text_embed = pad_sequence (text_embed_list , batch_first = True , padding_value = 0 )
@@ -331,4 +325,4 @@ def forward(
331325 x = self .norm_out (x , t )
332326 output = self .proj_out (x )
333327
334- return output
328+ return output
0 commit comments