@@ -52,6 +52,7 @@ def main(args):
5252 config_allow_defaults = True ,
5353 epoch = args .epoch ,
5454 mask_only = args .no_df_stage ,
55+ device = args .device ,
5556 )
5657 suffix = suffix if args .suffix else None
5758 if args .output_dir is None :
@@ -76,7 +77,12 @@ def main(args):
7677 progress = (i + 1 ) / n_samples * 100
7778 t0 = time .time ()
7879 audio = enhance (
79- model , df_state , audio , pad = args .compensate_delay , atten_lim_db = args .atten_lim
80+ model ,
81+ df_state ,
82+ audio ,
83+ pad = args .compensate_delay ,
84+ atten_lim_db = args .atten_lim ,
85+ device = args .device ,
8086 )
8187 t1 = time .time ()
8288 t_audio = audio .shape [- 1 ] / df_sr
@@ -107,6 +113,7 @@ def init_df(
107113 epoch : Union [str , int , None ] = "best" ,
108114 default_model : str = DEFAULT_MODEL ,
109115 mask_only : bool = False ,
116+ device : Optional [str ] = None ,
110117) -> Tuple [nn .Module , DF , str , int ]:
111118 """Initializes and loads config, model and deep filtering state.
112119
@@ -119,6 +126,8 @@ def init_df(
119126 config_allow_defaults (bool): Whether to allow initializing new config values with defaults.
120127 epoch (str): Checkpoint epoch to load. Options are `best`, `latest`, `<int>`, and `none`.
121128 `none` disables checkpoint loading. Defaults to `best`.
129+ device (str): Set the torch compute device.
130+ If None, will automatically choose an available backend. (Optional)
122131
123132 Returns:
124133 model (nn.Module): Initialized model, moved to GPU if available.
@@ -177,17 +186,19 @@ def init_df(
177186 logger .error ("Could not find a checkpoint" )
178187 exit (1 )
179188 logger .debug (f"Loaded checkpoint from epoch { epoch } " )
180- model = model .to (get_device ())
189+
190+ compute_device = get_device (device = device )
191+ model = model .to (compute_device )
181192 # Set suffix to model name
182193 suffix = os .path .basename (os .path .abspath (model_base_dir ))
183194 if post_filter :
184195 suffix += "_pf"
185- logger .info ("Running on device {}" .format (get_device () ))
196+ logger .info ("Running on device {}" .format (compute_device ))
186197 logger .info ("Model loaded" )
187198 return model , df_state , suffix , epoch
188199
189200
190- def df_features (audio : Tensor , df : DF , nb_df : int , device = None ) -> Tuple [Tensor , Tensor , Tensor ]:
201+ def df_features (audio : Tensor , df : DF , nb_df : int , device : Optional [ torch . device ] = None ) -> Tuple [Tensor , Tensor , Tensor ]:
191202 spec = df .analysis (audio .numpy ()) # [C, Tf] -> [C, Tf, F]
192203 a = get_norm_alpha (False )
193204 erb_fb = df .erb_widths ()
@@ -205,7 +216,12 @@ def df_features(audio: Tensor, df: DF, nb_df: int, device=None) -> Tuple[Tensor,
205216
206217@torch .no_grad ()
207218def enhance (
208- model : nn .Module , df_state : DF , audio : Tensor , pad = True , atten_lim_db : Optional [float ] = None
219+ model : nn .Module ,
220+ df_state : DF ,
221+ audio : Tensor ,
222+ pad = True ,
223+ atten_lim_db : Optional [float ] = None ,
224+ device : Optional [str ] = None ,
209225):
210226 """Enhance a single audio given a preloaded model and DF state.
211227
@@ -216,23 +232,30 @@ def enhance(
216232 pad (bool): Pad the audio to compensate for delay due to STFT/ISTFT.
217233 atten_lim_db (float): An optional noise attenuation limit in dB. E.g. an attenuation limit of
218234 12 dB only suppresses 12 dB and keeps the remaining noise in the resulting audio.
235+ device (str): Set the torch compute device.
236+ If None, will automatically choose an available backend. (Optional)
219237
220238 Returns:
221239 enhanced audio (Tensor): If `pad` was `False` of shape [C, T'] where T'<T slightly delayed due to STFT.
222240 If `pad` was `True` it has the same shape as the input.
223241 """
242+ compute_device = get_device (device = device )
243+ model .to (compute_device )
224244 model .eval ()
245+
225246 bs = audio .shape [0 ]
226247 if hasattr (model , "reset_h0" ):
227- model .reset_h0 (batch_size = bs , device = get_device () )
248+ model .reset_h0 (batch_size = bs , device = compute_device )
228249 orig_len = audio .shape [- 1 ]
229250 n_fft , hop = 0 , 0
230251 if pad :
231252 n_fft , hop = df_state .fft_size (), df_state .hop_size ()
232253 # Pad audio to compensate for the delay due to the real-time STFT implementation
233254 audio = F .pad (audio , (0 , n_fft ))
234255 nb_df = getattr (model , "nb_df" , getattr (model , "df_bins" , ModelParams ().nb_df ))
235- spec , erb_feat , spec_feat = df_features (audio , df_state , nb_df , device = get_device ())
256+ spec , erb_feat , spec_feat = df_features (
257+ audio , df_state , nb_df , device = compute_device
258+ )
236259 enhanced = model (spec .clone (), erb_feat , spec_feat )[0 ].cpu ()
237260 enhanced = as_complex (enhanced .squeeze (1 ))
238261 if atten_lim_db is not None and abs (atten_lim_db ) > 0 :
@@ -375,6 +398,11 @@ def run():
375398 help = "Don't add the model suffix to the enhanced audio files" ,
376399 )
377400 parser .add_argument ("--no-df-stage" , action = "store_true" )
401+ parser .add_argument (
402+ "--device" ,
403+ type = str ,
404+ help = "Set the torch compute device" ,
405+ )
378406 args = parser .parse_args ()
379407 main (args )
380408
0 commit comments