Skip to content

Commit 1e52c60

Browse files
committed
update gradio
1 parent 2ea4149 commit 1e52c60

File tree

2 files changed

+18
-15
lines changed

2 files changed

+18
-15
lines changed

cosyvoice/flow/flow_matching.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import threading
1415
import torch
1516
import torch.nn.functional as F
1617
from matcha.models.components.flow_matching import BASECFM
@@ -30,6 +31,7 @@ def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator:
3031
in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
3132
# Just change the architecture of the estimator here
3233
self.estimator = estimator
34+
self.lock = threading.Lock()
3335

3436
@torch.inference_mode()
3537
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, flow_cache=torch.zeros(1, 80, 0, 2)):
@@ -123,20 +125,21 @@ def forward_estimator(self, x, mask, mu, t, spks, cond):
123125
if isinstance(self.estimator, torch.nn.Module):
124126
return self.estimator.forward(x, mask, mu, t, spks, cond)
125127
else:
126-
self.estimator.set_input_shape('x', (2, 80, x.size(2)))
127-
self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
128-
self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
129-
self.estimator.set_input_shape('t', (2,))
130-
self.estimator.set_input_shape('spks', (2, 80))
131-
self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
132-
# run trt engine
133-
self.estimator.execute_v2([x.contiguous().data_ptr(),
134-
mask.contiguous().data_ptr(),
135-
mu.contiguous().data_ptr(),
136-
t.contiguous().data_ptr(),
137-
spks.contiguous().data_ptr(),
138-
cond.contiguous().data_ptr(),
139-
x.data_ptr()])
128+
with self.lock:
129+
self.estimator.set_input_shape('x', (2, 80, x.size(2)))
130+
self.estimator.set_input_shape('mask', (2, 1, x.size(2)))
131+
self.estimator.set_input_shape('mu', (2, 80, x.size(2)))
132+
self.estimator.set_input_shape('t', (2,))
133+
self.estimator.set_input_shape('spks', (2, 80))
134+
self.estimator.set_input_shape('cond', (2, 80, x.size(2)))
135+
# run trt engine
136+
self.estimator.execute_v2([x.contiguous().data_ptr(),
137+
mask.contiguous().data_ptr(),
138+
mu.contiguous().data_ptr(),
139+
t.contiguous().data_ptr(),
140+
spks.contiguous().data_ptr(),
141+
cond.contiguous().data_ptr(),
142+
x.data_ptr()])
140143
return x
141144

142145
def compute_loss(self, x1, mask, mu, spks=None, cond=None):

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ conformer==0.3.2
44
deepspeed==0.14.2; sys_platform == 'linux'
55
diffusers==0.27.2
66
gdown==5.1.0
7-
gradio==4.32.2
7+
gradio==5.4.0
88
grpcio==1.57.0
99
grpcio-tools==1.57.0
1010
huggingface-hub==0.25.2

0 commit comments

Comments
 (0)