1111# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212# See the License for the specific language governing permissions and
1313# limitations under the License.
14+ import threading
1415import torch
1516import torch .nn .functional as F
1617from matcha .models .components .flow_matching import BASECFM
@@ -30,6 +31,7 @@ def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator:
3031 in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0 )
3132 # Just change the architecture of the estimator here
3233 self .estimator = estimator
34+ self .lock = threading .Lock ()
3335
3436 @torch .inference_mode ()
3537 def forward (self , mu , mask , n_timesteps , temperature = 1.0 , spks = None , cond = None , prompt_len = 0 , flow_cache = torch .zeros (1 , 80 , 0 , 2 )):
@@ -123,20 +125,21 @@ def forward_estimator(self, x, mask, mu, t, spks, cond):
123125 if isinstance (self .estimator , torch .nn .Module ):
124126 return self .estimator .forward (x , mask , mu , t , spks , cond )
125127 else :
126- self .estimator .set_input_shape ('x' , (2 , 80 , x .size (2 )))
127- self .estimator .set_input_shape ('mask' , (2 , 1 , x .size (2 )))
128- self .estimator .set_input_shape ('mu' , (2 , 80 , x .size (2 )))
129- self .estimator .set_input_shape ('t' , (2 ,))
130- self .estimator .set_input_shape ('spks' , (2 , 80 ))
131- self .estimator .set_input_shape ('cond' , (2 , 80 , x .size (2 )))
132- # run trt engine
133- self .estimator .execute_v2 ([x .contiguous ().data_ptr (),
134- mask .contiguous ().data_ptr (),
135- mu .contiguous ().data_ptr (),
136- t .contiguous ().data_ptr (),
137- spks .contiguous ().data_ptr (),
138- cond .contiguous ().data_ptr (),
139- x .data_ptr ()])
128+ with self .lock :
129+ self .estimator .set_input_shape ('x' , (2 , 80 , x .size (2 )))
130+ self .estimator .set_input_shape ('mask' , (2 , 1 , x .size (2 )))
131+ self .estimator .set_input_shape ('mu' , (2 , 80 , x .size (2 )))
132+ self .estimator .set_input_shape ('t' , (2 ,))
133+ self .estimator .set_input_shape ('spks' , (2 , 80 ))
134+ self .estimator .set_input_shape ('cond' , (2 , 80 , x .size (2 )))
135+ # run trt engine
136+ self .estimator .execute_v2 ([x .contiguous ().data_ptr (),
137+ mask .contiguous ().data_ptr (),
138+ mu .contiguous ().data_ptr (),
139+ t .contiguous ().data_ptr (),
140+ spks .contiguous ().data_ptr (),
141+ cond .contiguous ().data_ptr (),
142+ x .data_ptr ()])
140143 return x
141144
142145 def compute_loss (self , x1 , mask , mu , spks = None , cond = None ):
0 commit comments