File tree: 1 file changed (+15 -12 lines changed)
src/f5_tts/model/backbones: 1 file changed (+15 -12 lines changed)
Columns: original file line number | diff line number | diff line change
@@ -238,18 +238,21 @@ def get_input_embed(
238238 audio_mask : bool ["b n" ] | None = None , # noqa: F722
239239 ):
240240 if self .text_uncond is None or self .text_cond is None or not cache :
241- batch = x .shape [0 ]
242- seq_lens = audio_mask .sum (dim = 1 )
243- text_embed_list = []
244- for i in range (batch ):
245- text_embed_i = self .text_embed (
246- text [i ].unsqueeze (0 ),
247- seq_lens [i ].item (),
248- drop_text = drop_text ,
249- audio_mask = audio_mask ,
250- )
251- text_embed_list .append (text_embed_i [0 ])
252- text_embed = pad_sequence (text_embed_list , batch_first = True , padding_value = 0 )
241+ if audio_mask is None :
242+ text_embed = self .text_embed (text , x .shape [1 ], drop_text = drop_text , audio_mask = audio_mask )
243+ else :
244+ batch = x .shape [0 ]
245+ seq_lens = audio_mask .sum (dim = 1 )
246+ text_embed_list = []
247+ for i in range (batch ):
248+ text_embed_i = self .text_embed (
249+ text [i ].unsqueeze (0 ),
250+ seq_lens [i ].item (),
251+ drop_text = drop_text ,
252+ audio_mask = audio_mask ,
253+ )
254+ text_embed_list .append (text_embed_i [0 ])
255+ text_embed = pad_sequence (text_embed_list , batch_first = True , padding_value = 0 )
253256 if cache :
254257 if drop_text :
255258 self .text_uncond = text_embed
You can’t perform that action at this time.
0 commit comments