Skip to content

Commit d9a6945

Browse files
committed
formatting
1 parent bc15df2 commit d9a6945

File tree

1 file changed

+6
-6
lines changed
  • src/f5_tts/model/backbones

1 file changed

+6
-6
lines changed

src/f5_tts/model/backbones/dit.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -66,18 +66,18 @@ def average_upsample_text_by_mask(self, text, text_mask):
6666

6767
valid_ind = torch.where(valid_mask)[0]
6868
valid_data = text[0, valid_ind, :] # [valid_len, text_dim]
69-
69+
7070
base_repeat = audio_len // valid_len
7171
remainder = audio_len % valid_len
72-
72+
7373
indices = []
7474
for j in range(valid_len):
7575
repeat_count = base_repeat + (1 if j >= valid_len - remainder else 0)
7676
indices.extend([j] * repeat_count)
77-
77+
7878
indices = torch.tensor(indices[:audio_len], device=text.device, dtype=torch.long)
7979
upsampled = valid_data[indices] # [audio_len, text_dim]
80-
80+
8181
upsampled_text[0, :audio_len, :] = upsampled
8282

8383
return upsampled_text
@@ -245,7 +245,7 @@ def get_input_embed(
245245
text_embed = self.text_embed(text, x.shape[1], drop_text=drop_text)
246246
else:
247247
batch = x.shape[0]
248-
seq_lens = audio_mask.sum(dim=1) # Calculate the actual sequence length for each sample
248+
seq_lens = audio_mask.sum(dim=1) # Calculate the actual sequence length for each sample
249249
text_embed_list = []
250250
for i in range(batch):
251251
text_embed_i = self.text_embed(
@@ -325,4 +325,4 @@ def forward(
325325
x = self.norm_out(x, t)
326326
output = self.proj_out(x)
327327

328-
return output
328+
return output

0 commit comments

Comments
 (0)