
Commit e67d508

runtime trtllm: fix batch inference skipping last words in shorter sentences #1039 #1179
1 parent 6b07fb0 commit e67d508

File tree

3 files changed: +54 −30 lines

src/f5_tts/runtime/triton_trtllm/patch/f5tts/model.py

Lines changed: 30 additions & 5 deletions

@@ -4,11 +4,20 @@
 import sys
 from collections import OrderedDict

+import numpy as np
 import tensorrt as trt
 from tensorrt_llm._common import default_net

 from ..._utils import str_dtype_to_trt
-from ...functional import Tensor, concat
+from ...functional import (
+    Tensor,
+    concat,
+    constant,
+    expand,
+    shape,
+    slice,
+    unsqueeze,
+)
 from ...layers import Linear
 from ...module import Module, ModuleList
 from ...plugin import current_all_reduce_helper
@@ -27,9 +36,9 @@ def __init__(self, mel_dim, text_dim, out_dim):
         self.proj = Linear(mel_dim * 2 + text_dim, out_dim)
         self.conv_pos_embed = ConvPositionEmbedding(dim=out_dim)

-    def forward(self, x, cond):
+    def forward(self, x, cond, mask=None):
         x = self.proj(concat([x, cond], dim=-1))
-        return self.conv_pos_embed(x) + x
+        return self.conv_pos_embed(x, mask=mask) + x


 class F5TTS(PretrainedModel):
@@ -69,10 +78,26 @@ def forward(
         input_lengths,
         scale=1.0,
     ):
+        if default_net().plugin_config.remove_input_padding:
+            mask = None
+        else:
+            N = shape(noise, 1)
+            B = shape(noise, 0)
+            seq_len_2d = concat([1, N])
+            max_position_embeddings = 4096
+            # create position ids
+            position_ids_buffer = constant(np.expand_dims(np.arange(max_position_embeddings).astype(np.int32), 0))
+            tmp_position_ids = slice(position_ids_buffer, starts=[0, 0], sizes=seq_len_2d)
+            tmp_position_ids = expand(tmp_position_ids, concat([B, N]))  # [B, N]
+            tmp_input_lengths = unsqueeze(input_lengths, 1)  # [B, 1]
+            tmp_input_lengths = expand(tmp_input_lengths, concat([B, N]))  # [B, N]
+            mask = tmp_position_ids < tmp_input_lengths  # [B, N]
+            mask = mask.cast("int32")
+
         t = self.time_embed(time)
-        x = self.input_embed(noise, cond)
+        x = self.input_embed(noise, cond, mask=mask)
         for block in self.transformer_blocks:
-            x = block(x, t, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale)
+            x = block(x, t, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale, mask=mask)
         denoise = self.proj_out(self.norm_out(x, t))
         denoise.mark_output("denoised", self.dtype)
         return denoise

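Note: the block added to F5TTS.forward runs only in the padded-batch path and produces a [B, N] mask marking which frames of each padded sequence are real. A minimal NumPy sketch of the same comparison (illustration only; build_padding_mask is an invented helper, the patch itself builds this symbolically with constant/slice/expand in the TRT-LLM graph):

import numpy as np

def build_padding_mask(input_lengths: np.ndarray, max_len: int) -> np.ndarray:
    """Return a [B, max_len] int32 mask: 1 for real frames, 0 for padding."""
    positions = np.arange(max_len, dtype=np.int32)[None, :]  # [1, N], like the position_ids_buffer slice
    lengths = input_lengths.astype(np.int32)[:, None]        # [B, 1], like unsqueeze(input_lengths, 1)
    return (positions < lengths).astype(np.int32)             # [B, N], like mask.cast("int32")

# Two sentences padded to 5 frames, the second only 3 frames long:
print(build_padding_mask(np.array([5, 3]), 5))
# [[1 1 1 1 1]
#  [1 1 1 0 0]]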
src/f5_tts/runtime/triton_trtllm/patch/f5tts/modules.py

Lines changed: 23 additions & 24 deletions

@@ -16,7 +16,6 @@
     chunk,
     concat,
     constant,
-    expand,
     expand_dims,
     expand_dims_like,
     expand_mask,
@@ -95,15 +94,24 @@ def __init__(self, dim, kernel_size=31, groups=16):
         self.conv1d2 = Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2)
         self.mish = Mish()

-    def forward(self, x, mask=None):  # noqa: F722
+    def forward(self, x, mask=None):
         if default_net().plugin_config.remove_input_padding:
             x = unsqueeze(x, 0)
-        x = permute(x, [0, 2, 1])
-        x = self.mish(self.conv1d2(self.mish(self.conv1d1(x))))
-        out = permute(x, [0, 2, 1])
+        if mask is not None:
+            mask = mask.view(concat([shape(mask, 0), 1, shape(mask, 1)]))  # [B 1 N]
+            mask = expand_dims_like(mask, x)  # [B D N]
+            mask = cast(mask, x.dtype)
+        x = permute(x, [0, 2, 1])  # [B D N]
+
+        if mask is not None:
+            x = self.mish(self.conv1d2(self.mish(self.conv1d1(x * mask) * mask)) * mask)
+        else:
+            x = self.mish(self.conv1d2(self.mish(self.conv1d1(x))))
+
+        x = permute(x, [0, 2, 1])  # [B N D]
         if default_net().plugin_config.remove_input_padding:
-            out = squeeze(out, 0)
-        return out
+            x = squeeze(x, 0)
+        return x


 class Attention(Module):
@@ -185,6 +193,7 @@ def forward(
         rope_cos,
         rope_sin,
         input_lengths,
+        mask=None,
         c=None,  # context c
         scale=1.0,
         rope=None,
@@ -283,6 +292,7 @@ def __call__(
         input_lengths,
         scale=1.0,
         rope=None,
+        mask=None,
     ) -> torch.FloatTensor:
         query = attn.to_q(x)
         key = attn.to_k(x)
@@ -295,20 +305,8 @@
         inner_dim = key.shape[-1]
         norm_factor = math.sqrt(attn.attention_head_size)
         q_scaling = 1.0 / norm_factor
-        mask = None
-        if not default_net().plugin_config.remove_input_padding:
-            N = shape(x, 1)
-            B = shape(x, 0)
-            seq_len_2d = concat([1, N])
-            max_position_embeddings = 4096
-            # create position ids
-            position_ids_buffer = constant(np.expand_dims(np.arange(max_position_embeddings).astype(np.int32), 0))
-            tmp_position_ids = slice(position_ids_buffer, starts=[0, 0], sizes=seq_len_2d)
-            tmp_position_ids = expand(tmp_position_ids, concat([B, N]))  # BxL
-            tmp_input_lengths = unsqueeze(input_lengths, 1)  # Bx1
-            tmp_input_lengths = expand(tmp_input_lengths, concat([B, N]))  # BxL
-            mask = tmp_position_ids < tmp_input_lengths  # BxL
-            mask = mask.cast("int32")
+        if default_net().plugin_config.remove_input_padding:
+            mask = None

         if default_net().plugin_config.bert_attention_plugin:
             qkv = concat([query, key, value], dim=-1)
@@ -393,14 +391,15 @@ def __init__(self, dim, heads, dim_head, ff_mult=2, dropout=0.1, pe_attn_head=No
         self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout)

     def forward(
-        self, x, t, rope_cos, rope_sin, input_lengths, scale=1.0, rope=ModuleNotFoundError
+        self, x, t, rope_cos, rope_sin, input_lengths, scale=1.0, rope=ModuleNotFoundError, mask=None
     ):  # x: noised input, t: time embedding
         # pre-norm & modulation for attention input
         norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
         # attention
         # norm ----> (2,1226,1024)
-        attn_output = self.attn(x=norm, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale)
-
+        attn_output = self.attn(
+            x=norm, rope_cos=rope_cos, rope_sin=rope_sin, input_lengths=input_lengths, scale=scale, mask=mask
+        )
         # process attention output for input x
         if default_net().plugin_config.remove_input_padding:
             x = x + gate_msa * attn_output

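Note: the behavioural change sits in ConvPositionEmbedding. Each depthwise Conv1d looks kernel_size // 2 = 15 frames to either side, so in a padded batch the frames past a shorter sentence's end would otherwise mix into its last real frames, which matches the "last words skipped" symptom in the commit title. A rough PyTorch sketch of the masking pattern (illustration only; MaskedConvPosEmbedSketch is invented, and the real module uses the tensorrt_llm functional API rather than torch.nn):

import torch
import torch.nn as nn

class MaskedConvPosEmbedSketch(nn.Module):
    """Invented stand-in for ConvPositionEmbedding showing where the mask is applied."""

    def __init__(self, dim=1024, kernel_size=31, groups=16):
        super().__init__()
        pad = kernel_size // 2
        self.conv1d1 = nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=pad)
        self.conv1d2 = nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=pad)
        self.mish = nn.Mish()

    def forward(self, x, mask=None):
        # x: [B, N, D]; mask: [B, N] with 1 for real frames, 0 for padding
        x = x.transpose(1, 2)  # [B, D, N]
        if mask is not None:
            m = mask[:, None, :].to(x.dtype)  # [B, 1, N], broadcasts over channels
            # zero the padded frames before and after each conv so they neither
            # receive from nor contribute to the real frames across the boundary
            x = self.mish(self.conv1d2(self.mish(self.conv1d1(x * m) * m)) * m)
        else:
            x = self.mish(self.conv1d2(self.mish(self.conv1d1(x))))
        return x.transpose(1, 2)  # [B, N, D]

The attention path already built an equivalent [B, N] mask inside the attention processor; the patch lifts that construction up into F5TTS.forward so the same mask is threaded through both the transformer blocks and the convolutional position embedding.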
src/f5_tts/runtime/triton_trtllm/run.sh

Lines changed: 1 addition & 1 deletion

@@ -73,7 +73,7 @@ fi

 if [ $stage -le 7 ] && [ $stop_stage -ge 7 ]; then
   echo "TRT-LLM: offline decoding benchmark test"
-  batch_size=1
+  batch_size=2
   split_name=wenetspeech4tts
   backend_type=trt
   log_dir=./tests/benchmark_${model}_batch_size_${batch_size}_${split_name}_${backend_type}

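Note: with batch_size=1 the padded-batch mask path is never exercised, so the stage-7 offline benchmark is presumably raised to batch_size=2 to cover the batched case in which the behaviour described in #1039/#1179 (shorter sentences losing their last words) shows up.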