Skip to content

Commit a17c5ae

Browse files
committed
pytorch implementation: fix batch 1 inference from last commit
1 parent a0b8fb5 commit a17c5ae

File tree

1 file changed

+15
-12
lines changed
  • src/f5_tts/model/backbones

1 file changed

+15
-12
lines changed

src/f5_tts/model/backbones/dit.py

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -238,18 +238,21 @@ def get_input_embed(
238238
audio_mask: bool["b n"] | None = None, # noqa: F722
239239
):
240240
if self.text_uncond is None or self.text_cond is None or not cache:
241-
batch = x.shape[0]
242-
seq_lens = audio_mask.sum(dim=1)
243-
text_embed_list = []
244-
for i in range(batch):
245-
text_embed_i = self.text_embed(
246-
text[i].unsqueeze(0),
247-
seq_lens[i].item(),
248-
drop_text=drop_text,
249-
audio_mask=audio_mask,
250-
)
251-
text_embed_list.append(text_embed_i[0])
252-
text_embed = pad_sequence(text_embed_list, batch_first=True, padding_value=0)
241+
if audio_mask is None:
242+
text_embed = self.text_embed(text, x.shape[1], drop_text=drop_text, audio_mask=audio_mask)
243+
else:
244+
batch = x.shape[0]
245+
seq_lens = audio_mask.sum(dim=1)
246+
text_embed_list = []
247+
for i in range(batch):
248+
text_embed_i = self.text_embed(
249+
text[i].unsqueeze(0),
250+
seq_lens[i].item(),
251+
drop_text=drop_text,
252+
audio_mask=audio_mask,
253+
)
254+
text_embed_list.append(text_embed_i[0])
255+
text_embed = pad_sequence(text_embed_list, batch_first=True, padding_value=0)
253256
if cache:
254257
if drop_text:
255258
self.text_uncond = text_embed

0 commit comments

Comments (0)