From cfd5238df1faa694f65d79ec66e59a3351227001 Mon Sep 17 00:00:00 2001
From: benniekiss <63211101+benniekiss@users.noreply.github.com>
Date: Fri, 1 Nov 2024 13:42:25 -0400
Subject: [PATCH] allow specifying compute device

---
 DeepFilterNet/df/enhance.py | 42 ++++++++++++++++++++++++++++++-------
 DeepFilterNet/df/utils.py   | 16 +++++++++-----
 2 files changed, 46 insertions(+), 12 deletions(-)

diff --git a/DeepFilterNet/df/enhance.py b/DeepFilterNet/df/enhance.py
index 43a7c6285..6411dee9a 100644
--- a/DeepFilterNet/df/enhance.py
+++ b/DeepFilterNet/df/enhance.py
@@ -52,6 +52,7 @@ def main(args):
         config_allow_defaults=True,
         epoch=args.epoch,
         mask_only=args.no_df_stage,
+        device=args.device,
     )
     suffix = suffix if args.suffix else None
     if args.output_dir is None:
@@ -76,7 +77,12 @@ def main(args):
         progress = (i + 1) / n_samples * 100
         t0 = time.time()
         audio = enhance(
-            model, df_state, audio, pad=args.compensate_delay, atten_lim_db=args.atten_lim
+            model,
+            df_state,
+            audio,
+            pad=args.compensate_delay,
+            atten_lim_db=args.atten_lim,
+            device=args.device,
         )
         t1 = time.time()
         t_audio = audio.shape[-1] / df_sr
@@ -107,6 +113,7 @@ def init_df(
     epoch: Union[str, int, None] = "best",
     default_model: str = DEFAULT_MODEL,
     mask_only: bool = False,
+    device: Optional[str] = None,
 ) -> Tuple[nn.Module, DF, str, int]:
     """Initializes and loads config, model and deep filtering state.

@@ -119,6 +126,8 @@ def init_df(
         config_allow_defaults (bool): Whether to allow initializing new config values with defaults.
         epoch (str): Checkpoint epoch to load. Options are `best`, `latest`, `<int>`, and `none`.
             `none` disables checkpoint loading. Defaults to `best`.
+        device (str): Set the torch compute device.
+            If `None`, will automatically choose an available backend. (Optional)

     Returns:
         model (nn.Modules): Intialized model, moved to GPU if available.
@@ -177,17 +186,19 @@ def init_df(
             logger.error("Could not find a checkpoint")
             exit(1)
         logger.debug(f"Loaded checkpoint from epoch {epoch}")
-    model = model.to(get_device())
+
+    compute_device = get_device(device=device)
+    model = model.to(compute_device)
     # Set suffix to model name
     suffix = os.path.basename(os.path.abspath(model_base_dir))
     if post_filter:
         suffix += "_pf"
-    logger.info("Running on device {}".format(get_device()))
+    logger.info("Running on device {}".format(compute_device))
     logger.info("Model loaded")
     return model, df_state, suffix, epoch


-def df_features(audio: Tensor, df: DF, nb_df: int, device=None) -> Tuple[Tensor, Tensor, Tensor]:
+def df_features(audio: Tensor, df: DF, nb_df: int, device: Optional[torch.device] = None) -> Tuple[Tensor, Tensor, Tensor]:
     spec = df.analysis(audio.numpy())  # [C, Tf] -> [C, Tf, F]
     a = get_norm_alpha(False)
     erb_fb = df.erb_widths()
@@ -205,7 +216,12 @@ def df_features(audio: Tensor, df: DF, nb_df: int, device=None) -> Tuple[Tensor,

 @torch.no_grad()
 def enhance(
-    model: nn.Module, df_state: DF, audio: Tensor, pad=True, atten_lim_db: Optional[float] = None
+    model: nn.Module,
+    df_state: DF,
+    audio: Tensor,
+    pad=True,
+    atten_lim_db: Optional[float] = None,
+    device: Optional[str] = None,
 ):
     """Enhance a single audio given a preloaded model and DF state.

@@ -216,15 +232,20 @@ def enhance(
         pad (bool): Pad the audio to compensate for delay due to STFT/ISTFT.
         atten_lim_db (float): An optional noise attenuation limit in dB. E.g. an attenuation limit of
             12 dB only suppresses 12 dB and keeps the remaining noise in the resulting audio.
+        device (str): Set the torch compute device.
+            If `None`, will automatically choose an available backend. (Optional)

     Returns:
         enhanced audio (Tensor): If `pad` was `False` of shape [C, T'] where T'<T
@@ -375,6 +398,11 @@ def run():
         help="Don't add the model suffix to the enhanced audio files",
     )
     parser.add_argument("--no-df-stage", action="store_true")
+    parser.add_argument(
+        "--device",
+        type=str,
+        help="Set the torch compute device",
+    )
     args = parser.parse_args()
     main(args)

diff --git a/DeepFilterNet/df/utils.py b/DeepFilterNet/df/utils.py
index cea7a9b3e..3adf3512a 100644
--- a/DeepFilterNet/df/utils.py
+++ b/DeepFilterNet/df/utils.py
@@ -17,13 +17,19 @@ from df.model import ModelParams


-def get_device():
-    s = config("DEVICE", default="", section="train")
-    if s == "":
+def get_device(device: Optional[str] = None):
+    s = device or config("DEVICE", default="", section="train")
+    if not s:
         if torch.cuda.is_available():
-            DEVICE = torch.device("cuda:0")
-        else:
+            DEVICE = torch.device("cuda")
+        elif torch.mps.is_available():
+            DEVICE = torch.device("mps")
+        elif torch.xpu.is_available():
+            DEVICE = torch.device("xpu")
+        elif torch.cpu.is_available():
             DEVICE = torch.device("cpu")
+        else:
+            raise RuntimeError("No compute devices found")
     else:
         DEVICE = torch.device(s)
     return DEVICE
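For reference, a minimal usage sketch of the API as changed by this patch: `init_df` and `enhance` are called with the new `device` argument, and `init_df` is unpacked as the four-tuple shown in the diff. The `load_audio`/`save_audio` helpers and the file names are assumptions taken from the surrounding DeepFilterNet code, not something this patch introduces.

```python
# Sketch only: exercises the new `device` argument from this patch.
# `load_audio`/`save_audio` and the file names are assumed, not part of the patch.
from df.enhance import enhance, init_df, load_audio, save_audio

# Pick the backend explicitly instead of relying on get_device() auto-detection.
model, df_state, _, _ = init_df(device="cpu")  # e.g. "cuda", "mps", "xpu", "cpu"

# Load (and resample) the noisy input, enhance it on the chosen device, and save.
audio, _ = load_audio("noisy.wav", sr=df_state.sr())
enhanced = enhance(model, df_state, audio, device="cpu")
save_audio("enhanced.wav", enhanced, df_state.sr())
```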