1+ """
2+ PyTorch-based Demucs benchmarking script.
3+
4+ This script uses PyTorch for inference and imports common functionality
5+ from benchmark_common to avoid code duplication.
6+ """
7+
8+ import argparse
9+ from pathlib import Path
10+
11+ import torchaudio
12+ from demucs .apply import apply_model
13+ from demucs .pretrained import get_model
14+
15+ from benchmark_common import run_benchmark , STEM_MAP , STEM_NAMES
16+
17+ DEFAULT_MODEL = 'htdemucs'
18+
19+
20+ def separate_pytorch (mixture_path , out_dir , model_name ):
21+ """Separate audio using PyTorch Demucs model."""
22+ # Load model
23+ model = get_model (model_name )
24+
25+ # Load and preprocess audio
26+ audio , rate = torchaudio .load (str (mixture_path ))
27+ if rate != 44100 :
28+ audio = torchaudio .functional .resample (audio , rate , 44100 )
29+
30+ # Normalize
31+ ref = audio .mean (0 )
32+ audio = (audio - ref .mean ()) / ref .std ()
33+
34+ # Apply model
35+ sources = apply_model (model , audio [None ])[0 ]
36+
37+ # Denormalize
38+ sources = sources * ref .std () + ref .mean ()
39+
40+ # Save stems
41+ out_dir = Path (out_dir )
42+ out_dir .mkdir (parents = True , exist_ok = True )
43+ for target_idx in range (len (STEM_NAMES )):
44+ target_name = STEM_MAP [target_idx ]
45+ out_audio = sources [target_idx ].detach ().cpu ()
46+ out_path = out_dir / f'target_{ target_idx } _{ target_name } .wav'
47+ torchaudio .save (str (out_path ), out_audio , sample_rate = 44100 )
48+
49+
50+ def main ():
51+ parser = argparse .ArgumentParser (description = 'PyTorch Demucs Benchmarking Script' )
52+ parser .add_argument ('--musdb-root' , type = Path , help = 'Path to MusDB root' , required = True )
53+ parser .add_argument ('--output-root' , type = Path , default = None , help = 'Where to store separated outputs (default: inside musdb-root)' )
54+ parser .add_argument ('--output-dir' , type = str , default = 'test-separated-pytorch' , help = 'Output directory name (default: test-separated-pytorch)' )
55+ parser .add_argument ('--json-out' , type = str , default = 'benchmark_results_pytorch.json' , help = 'Output JSON file for benchmarks' )
56+ parser .add_argument ('--force-reseparate' , action = 'store_true' , help = 'Force re-separation even if files already exist' )
57+ args = parser .parse_args ()
58+
59+ # Setup paths
60+ musdb_root = Path (args .musdb_root )
61+ output_root = Path (args .output_root ) if args .output_root else musdb_root
62+ out_dir = output_root / args .output_dir
63+ out_dir .mkdir (parents = True , exist_ok = True )
64+
65+ model_name = DEFAULT_MODEL
66+ print (f'Using model: { model_name } ' )
67+
68+ # Run benchmark using common flow
69+ run_benchmark (
70+ musdb_root = musdb_root ,
71+ out_dir = out_dir ,
72+ force_reseparate = args .force_reseparate ,
73+ json_out = args .json_out ,
74+ separate_func = separate_pytorch ,
75+ model_identifier = model_name ,
76+ # Arguments passed to separate_pytorch
77+ model_name = model_name
78+ )
79+
80+
81+ if __name__ == '__main__' :
82+ main ()
0 commit comments