Skip to content

Commit c950c22

Browse files
kylehowellsclaude
andcommitted
sts generate: simplify CLI with lazy imports and remove dead code
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 9c2860e commit c950c22

File tree

1 file changed

+5
-16
lines changed

1 file changed

+5
-16
lines changed

mlx_audio/sts/generate.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,7 @@
1313
import time
1414
from pathlib import Path
1515

16-
from mlx_audio.sts.models.deepfilternet import DeepFilterNetModel
17-
from mlx_audio.sts.models.mossformer2_se import MossFormer2SEModel
18-
19-
20-
MODEL_CLASSES = {
21-
"deepfilternet": DeepFilterNetModel,
22-
"mossformer2": MossFormer2SEModel,
23-
}
24-
25-
# Repo ID substrings to model class mapping
16+
# Repo ID substrings to model type mapping
2617
REPO_HINTS = {
2718
"deepfilter": "deepfilternet",
2819
"mossformer": "mossformer2",
@@ -116,6 +107,8 @@ def main():
116107
start = time.time()
117108

118109
if model_type == "deepfilternet":
110+
from mlx_audio.sts.models.deepfilternet import DeepFilterNetModel
111+
119112
load_kwargs = {"model_name_or_path": args.model}
120113
if args.version is not None:
121114
load_kwargs["version"] = args.version
@@ -125,19 +118,15 @@ def main():
125118
model = DeepFilterNetModel.from_pretrained(**load_kwargs)
126119

127120
if args.stream:
128-
try:
129-
model.enhance_file_streaming(str(in_path), str(out_path))
130-
except NotImplementedError as exc:
131-
raise NotImplementedError(
132-
f"Streaming unavailable for {model.model_version}: {exc}"
133-
) from exc
121+
model.enhance_file_streaming(str(in_path), str(out_path))
134122
mode = "streaming"
135123
else:
136124
model.enhance_file(str(in_path), str(out_path))
137125
mode = "offline"
138126

139127
elif model_type == "mossformer2":
140128
from mlx_audio import audio_io
129+
from mlx_audio.sts.models.mossformer2_se import MossFormer2SEModel
141130

142131
model = MossFormer2SEModel.from_pretrained(args.model)
143132
enhanced = model.enhance(str(in_path))

0 commit comments

Comments
 (0)