Skip to content

Commit 6db73c5

Browse files
committed
style: run pre-commit formatter
1 parent c950c22 commit 6db73c5

File tree

2 files changed

+14
-14
lines changed

2 files changed

+14
-14
lines changed

mlx_audio/sts/models/deepfilternet/model.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,7 @@ def from_pretrained(
110110
f"Missing config.json in model directory: {model_dir}"
111111
)
112112
if not weights_path.exists():
113-
raise FileNotFoundError(
114-
f"Missing model.safetensors in: {model_dir}"
115-
)
113+
raise FileNotFoundError(f"Missing model.safetensors in: {model_dir}")
116114
return cls._load_from_files(
117115
config_path=config_path,
118116
weights_path=weights_path,
@@ -123,12 +121,8 @@ def from_pretrained(
123121
hf_kwargs = {"repo_id": model_name_or_path}
124122
if subfolder:
125123
hf_kwargs["subfolder"] = subfolder
126-
config_path = Path(
127-
hf_hub_download(filename="config.json", **hf_kwargs)
128-
)
129-
weights_path = Path(
130-
hf_hub_download(filename="model.safetensors", **hf_kwargs)
131-
)
124+
config_path = Path(hf_hub_download(filename="config.json", **hf_kwargs))
125+
weights_path = Path(hf_hub_download(filename="model.safetensors", **hf_kwargs))
132126
return cls._load_from_files(
133127
config_path=config_path,
134128
weights_path=weights_path,

mlx_audio/sts/tests/test_deepfilternet.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,9 @@ def test_target_parity(self):
232232
"""
233233
import soundfile as sf
234234

235-
target_path = self.REPO_ROOT / "examples" / "denoise" / "noisey_audio_10s_target.wav"
235+
target_path = (
236+
self.REPO_ROOT / "examples" / "denoise" / "noisey_audio_10s_target.wav"
237+
)
236238
if not target_path.exists():
237239
self.skipTest(f"Target audio not found: {target_path}")
238240

@@ -250,7 +252,8 @@ def test_target_parity(self):
250252
# Correlation (actual ~0.9997, threshold 0.999)
251253
corr = float(np.corrcoef(target, mlx_out)[0, 1])
252254
self.assertGreater(
253-
corr, 0.999,
255+
corr,
256+
0.999,
254257
f"Correlation {corr:.6f} should be > 0.999",
255258
)
256259

@@ -259,14 +262,16 @@ def test_target_parity(self):
259262
error_power = float(np.mean((target - mlx_out) ** 2))
260263
ser_db = 10 * np.log10(signal_power / (error_power + 1e-10))
261264
self.assertGreater(
262-
ser_db, 25.0,
265+
ser_db,
266+
25.0,
263267
f"SER {ser_db:.1f} dB should be > 25 dB",
264268
)
265269

266270
# Mean Absolute Error (actual ~0.001, threshold 0.002)
267271
mae = float(np.mean(np.abs(target - mlx_out)))
268272
self.assertLess(
269-
mae, 2e-3,
273+
mae,
274+
2e-3,
270275
f"MAE {mae:.6f} should be < 0.002",
271276
)
272277

@@ -275,7 +280,8 @@ def test_target_parity(self):
275280
rms_mlx = float(np.sqrt(np.mean(mlx_out**2)))
276281
rms_diff_pct = abs(rms_target - rms_mlx) / (rms_target + 1e-10) * 100
277282
self.assertLess(
278-
rms_diff_pct, 1.0,
283+
rms_diff_pct,
284+
1.0,
279285
f"RMS difference {rms_diff_pct:.3f}% should be < 1%",
280286
)
281287

0 commit comments

Comments
 (0)