1313import time
1414from pathlib import Path
1515
16- from mlx_audio .sts .models .deepfilternet import DeepFilterNetModel
17- from mlx_audio .sts .models .mossformer2_se import MossFormer2SEModel
18-
19-
20- MODEL_CLASSES = {
21- "deepfilternet" : DeepFilterNetModel ,
22- "mossformer2" : MossFormer2SEModel ,
23- }
24-
25- # Repo ID substrings to model class mapping
16+ # Repo ID substrings to model type mapping
2617REPO_HINTS = {
2718 "deepfilter" : "deepfilternet" ,
2819 "mossformer" : "mossformer2" ,
@@ -116,6 +107,8 @@ def main():
116107 start = time .time ()
117108
118109 if model_type == "deepfilternet" :
110+ from mlx_audio .sts .models .deepfilternet import DeepFilterNetModel
111+
119112 load_kwargs = {"model_name_or_path" : args .model }
120113 if args .version is not None :
121114 load_kwargs ["version" ] = args .version
@@ -125,19 +118,15 @@ def main():
125118 model = DeepFilterNetModel .from_pretrained (** load_kwargs )
126119
127120 if args .stream :
128- try :
129- model .enhance_file_streaming (str (in_path ), str (out_path ))
130- except NotImplementedError as exc :
131- raise NotImplementedError (
132- f"Streaming unavailable for { model .model_version } : { exc } "
133- ) from exc
121+ model .enhance_file_streaming (str (in_path ), str (out_path ))
134122 mode = "streaming"
135123 else :
136124 model .enhance_file (str (in_path ), str (out_path ))
137125 mode = "offline"
138126
139127 elif model_type == "mossformer2" :
140128 from mlx_audio import audio_io
129+ from mlx_audio .sts .models .mossformer2_se import MossFormer2SEModel
141130
142131 model = MossFormer2SEModel .from_pretrained (args .model )
143132 enhanced = model .enhance (str (in_path ))
0 commit comments