Commit bc15df2

Merge pull request #1212 from QingyuLiu0521/fix/AverageUpsampling
Fix Average Upsampling conflict logic introduced by the previous batch inference fix.
2 parents 1dcb4e1 + 9b2357a commit bc15df2

File tree: 1 file changed, +29 −35 lines
  • src/f5_tts/model/backbones/dit.py


src/f5_tts/model/backbones/dit.py

Lines changed: 29 additions & 35 deletions
@@ -51,43 +51,38 @@ def __init__(
         else:
             self.extra_modeling = False
 
-    def average_upsample_text_by_mask(self, text, text_mask, audio_mask):
+    def average_upsample_text_by_mask(self, text, text_mask):
         batch, text_len, text_dim = text.shape
+        assert batch == 1
 
-        if audio_mask is None:
-            audio_mask = torch.ones_like(text_mask, dtype=torch.bool)
-        valid_mask = audio_mask & text_mask
-        audio_lens = audio_mask.sum(dim=1)  # [batch]
-        valid_lens = valid_mask.sum(dim=1)  # [batch]
+        valid_mask = text_mask[0]
+        audio_len = text_len
+        valid_len = valid_mask.sum().item()
 
-        upsampled_text = torch.zeros_like(text)
-
-        for i in range(batch):
-            audio_len = audio_lens[i].item()
-            valid_len = valid_lens[i].item()
-
-            if valid_len == 0:
-                continue
-
-            valid_ind = torch.where(valid_mask[i])[0]
-            valid_data = text[i, valid_ind, :]  # [valid_len, text_dim]
+        if valid_len == 0:
+            return torch.zeros_like(text)
 
-            base_repeat = audio_len // valid_len
-            remainder = audio_len % valid_len
-
-            indices = []
-            for j in range(valid_len):
-                repeat_count = base_repeat + (1 if j >= valid_len - remainder else 0)
-                indices.extend([j] * repeat_count)
-
-            indices = torch.tensor(indices[:audio_len], device=text.device, dtype=torch.long)
-            upsampled = valid_data[indices]  # [audio_len, text_dim]
+        upsampled_text = torch.zeros_like(text)
 
-            upsampled_text[i, :audio_len, :] = upsampled
+        valid_ind = torch.where(valid_mask)[0]
+        valid_data = text[0, valid_ind, :]  # [valid_len, text_dim]
+
+        base_repeat = audio_len // valid_len
+        remainder = audio_len % valid_len
+
+        indices = []
+        for j in range(valid_len):
+            repeat_count = base_repeat + (1 if j >= valid_len - remainder else 0)
+            indices.extend([j] * repeat_count)
+
+        indices = torch.tensor(indices[:audio_len], device=text.device, dtype=torch.long)
+        upsampled = valid_data[indices]  # [audio_len, text_dim]
+
+        upsampled_text[0, :audio_len, :] = upsampled
 
         return upsampled_text
 
-    def forward(self, text: int["b nt"], seq_len, drop_text=False, audio_mask: bool["b n"] | None = None):
+    def forward(self, text: int["b nt"], seq_len, drop_text=False):
         text = text + 1  # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
         text = text[:, :seq_len]  # curtail if character tokens are more than the mel spec tokens
         text = F.pad(text, (0, seq_len - text.shape[1]), value=0)  # (opt.) if not self.average_upsampling:
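
For reference, a minimal standalone sketch (not part of the diff) of the repeat-count logic in average_upsample_text_by_mask above: audio_len frames are spread across valid_len text tokens, and the last remainder tokens each receive one extra repeat. The toy lengths below are made up for illustration.

audio_len, valid_len = 10, 4
base_repeat = audio_len // valid_len   # 2 repeats for every token
remainder = audio_len % valid_len      # 2 leftover frames to distribute
indices = []
for j in range(valid_len):
    # the last `remainder` tokens get one extra repeat so the total equals audio_len
    repeat_count = base_repeat + (1 if j >= valid_len - remainder else 0)
    indices.extend([j] * repeat_count)
print(indices)  # [0, 0, 1, 1, 2, 2, 2, 3, 3, 3] -- one text index per audio frame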
@@ -114,7 +109,7 @@ def forward(self, text: int["b nt"], seq_len, drop_text=False, audio_mask: bool[
         text = self.text_blocks(text)
 
         if self.average_upsampling:
-            text = self.average_upsample_text_by_mask(text, ~text_mask, audio_mask)
+            text = self.average_upsample_text_by_mask(text, ~text_mask)
 
         return text
 
@@ -247,17 +242,16 @@ def get_input_embed(
     ):
         if self.text_uncond is None or self.text_cond is None or not cache:
            if audio_mask is None:
-                text_embed = self.text_embed(text, x.shape[1], drop_text=drop_text, audio_mask=audio_mask)
+                text_embed = self.text_embed(text, x.shape[1], drop_text=drop_text)
            else:
                batch = x.shape[0]
-                seq_lens = audio_mask.sum(dim=1)
+                seq_lens = audio_mask.sum(dim=1)  # Calculate the actual sequence length for each sample
                text_embed_list = []
                for i in range(batch):
                    text_embed_i = self.text_embed(
                        text[i].unsqueeze(0),
-                        seq_lens[i].item(),
+                        seq_len=seq_lens[i].item(),
                        drop_text=drop_text,
-                        audio_mask=audio_mask,
                    )
                    text_embed_list.append(text_embed_i[0])
                text_embed = pad_sequence(text_embed_list, batch_first=True, padding_value=0)
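
As a side note, here is a minimal sketch (with assumed shapes, not the repository's code) of the per-sample pattern used above: each sample is embedded at its own true length taken from audio_mask, and the variable-length results are padded back into a single batch tensor with pad_sequence.

import torch
from torch.nn.utils.rnn import pad_sequence

# hypothetical per-sample embeddings of lengths 7 and 5 with feature dim 512
embeds = [torch.randn(7, 512), torch.randn(5, 512)]
batched = pad_sequence(embeds, batch_first=True, padding_value=0)
print(batched.shape)  # torch.Size([2, 7, 512]) -- padded up to the longest sample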
@@ -331,4 +325,4 @@ def forward(
         x = self.norm_out(x, t)
         output = self.proj_out(x)
 
-        return output
+        return output
