Skip to content

Commit a0b8fb5

Browse files
committed
runtime trtllm: minor fixes. pytorch: update text_embedding logic to correct v0 batching.
1 parent c8bfc3a commit a0b8fb5

File tree

3 files changed

+36
-21
lines changed

3 files changed

+36
-21
lines changed

src/f5_tts/model/backbones/dit.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import torch
1313
import torch.nn.functional as F
1414
from torch import nn
15+
from torch.nn.utils.rnn import pad_sequence
1516
from x_transformers.x_transformers import RotaryEmbedding
1617

1718
from f5_tts.model.modules import (
@@ -236,19 +237,30 @@ def get_input_embed(
236237
cache: bool = True,
237238
audio_mask: bool["b n"] | None = None, # noqa: F722
238239
):
239-
seq_len = x.shape[1]
240-
# TODO. modify to get text_embed one by one (to avoid misalignment when batching), as done in runtime imple.
240+
if self.text_uncond is None or self.text_cond is None or not cache:
241+
batch = x.shape[0]
242+
seq_lens = audio_mask.sum(dim=1)
243+
text_embed_list = []
244+
for i in range(batch):
245+
text_embed_i = self.text_embed(
246+
text[i].unsqueeze(0),
247+
seq_lens[i].item(),
248+
drop_text=drop_text,
249+
audio_mask=audio_mask,
250+
)
251+
text_embed_list.append(text_embed_i[0])
252+
text_embed = pad_sequence(text_embed_list, batch_first=True, padding_value=0)
253+
if cache:
254+
if drop_text:
255+
self.text_uncond = text_embed
256+
else:
257+
self.text_cond = text_embed
258+
241259
if cache:
242260
if drop_text:
243-
if self.text_uncond is None:
244-
self.text_uncond = self.text_embed(text, seq_len, drop_text=True, audio_mask=audio_mask)
245261
text_embed = self.text_uncond
246262
else:
247-
if self.text_cond is None:
248-
self.text_cond = self.text_embed(text, seq_len, drop_text=False, audio_mask=audio_mask)
249263
text_embed = self.text_cond
250-
else:
251-
text_embed = self.text_embed(text, seq_len, drop_text=drop_text, audio_mask=audio_mask)
252264

253265
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
254266

src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/f5_tts_trtllm.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,16 @@ def __init__(
4242
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, precompute_max_pos), persistent=False)
4343
self.text_blocks = nn.Sequential(*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)])
4444

45-
def forward(self, text, seq_len):
45+
def forward(self, text, seq_len, drop_text=False):
4646
text = text + 1
4747
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
4848
text = F.pad(text, (0, seq_len - text.shape[1]), value=0)
4949
if self.mask_padding:
5050
text_mask = text == 0
5151

52+
if drop_text: # cfg for text
53+
text = torch.zeros_like(text)
54+
5255
text = self.text_embed(text) # b n -> b n d
5356
text = text + self.freqs_cis[:seq_len, :]
5457
if self.mask_padding:
@@ -385,17 +388,17 @@ def sample(
385388
# get text_embed one by one to avoid misalignment
386389
text_and_drop_embedding_list = []
387390
for i in range(batch):
388-
text_and_drop_embedding_i = self.text_embedding(
389-
torch.cat(
390-
(
391-
text_pad_sequence[i].unsqueeze(0).to(self.device),
392-
torch.full((1, text_pad_sequence.shape[1]), -1, dtype=torch.int32).to(self.device),
393-
),
394-
dim=0,
395-
),
391+
text_embedding_i = self.text_embedding(
392+
text_pad_sequence[i].unsqueeze(0).to(self.device),
393+
estimated_reference_target_mel_len[i],
394+
drop_text=False,
395+
)
396+
text_embedding_drop_i = self.text_embedding(
397+
text_pad_sequence[i].unsqueeze(0).to(self.device),
396398
estimated_reference_target_mel_len[i],
399+
drop_text=True,
397400
)
398-
text_and_drop_embedding_list.extend([text_and_drop_embedding_i[0], text_and_drop_embedding_i[1]])
401+
text_and_drop_embedding_list.extend([text_embedding_i[0], text_embedding_drop_i[0]])
399402

400403
# pad separately computed text_embed to form batch with max_seq_len
401404
text_and_drop_embedding = pad_sequence(

src/f5_tts/runtime/triton_trtllm/model_repo_f5_tts/f5_tts/1/model.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -229,7 +229,7 @@ def execute(self, requests):
229229
max_seq_len = min(max(estimated_reference_target_mel_len), self.max_mel_len)
230230

231231
batch = len(requests)
232-
mel_features = torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float16).to(self.device)
232+
mel_features = torch.zeros((batch, max_seq_len, self.n_mel_channels), dtype=torch.float32).to(self.device)
233233
for i, mel in enumerate(mel_features_list):
234234
mel_features[i, : mel.shape[1], :] = mel
235235

@@ -254,9 +254,9 @@ def execute(self, requests):
254254

255255
responses = []
256256
for i in range(batch):
257-
ref_me_len = reference_mel_len[i]
257+
ref_mel_len = reference_mel_len[i]
258258
estimated_mel_len = estimated_reference_target_mel_len[i]
259-
denoised_one_item = denoised[i, ref_me_len:estimated_mel_len, :].unsqueeze(0).transpose(1, 2)
259+
denoised_one_item = denoised[i, ref_mel_len:estimated_mel_len, :].unsqueeze(0).transpose(1, 2)
260260
audio = self.forward_vocoder(denoised_one_item)
261261
if reference_rms_list[i] < self.target_rms:
262262
audio = audio * reference_rms_list[i] / self.target_rms

0 commit comments

Comments (0)