Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 35 additions & 7 deletions DeepFilterNet/df/enhance.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def main(args):
config_allow_defaults=True,
epoch=args.epoch,
mask_only=args.no_df_stage,
device=args.device,
)
suffix = suffix if args.suffix else None
if args.output_dir is None:
Expand All @@ -76,7 +77,12 @@ def main(args):
progress = (i + 1) / n_samples * 100
t0 = time.time()
audio = enhance(
model, df_state, audio, pad=args.compensate_delay, atten_lim_db=args.atten_lim
model,
df_state,
audio,
pad=args.compensate_delay,
atten_lim_db=args.atten_lim,
device=args.device,
)
t1 = time.time()
t_audio = audio.shape[-1] / df_sr
Expand Down Expand Up @@ -107,6 +113,7 @@ def init_df(
epoch: Union[str, int, None] = "best",
default_model: str = DEFAULT_MODEL,
mask_only: bool = False,
device: Optional[str] = None,
) -> Tuple[nn.Module, DF, str, int]:
"""Initializes and loads config, model and deep filtering state.

Expand All @@ -119,6 +126,8 @@ def init_df(
config_allow_defaults (bool): Whether to allow initializing new config values with defaults.
epoch (str): Checkpoint epoch to load. Options are `best`, `latest`, `<int>`, and `none`.
`none` disables checkpoint loading. Defaults to `best`.
device (str): Set the torch compute device.
If `None`, will automatically choose an available backend. (Optional)

Returns:
model (nn.Modules): Intialized model, moved to GPU if available.
Expand Down Expand Up @@ -177,17 +186,19 @@ def init_df(
logger.error("Could not find a checkpoint")
exit(1)
logger.debug(f"Loaded checkpoint from epoch {epoch}")
model = model.to(get_device())

compute_device = get_device(device=device)
model = model.to(compute_device)
# Set suffix to model name
suffix = os.path.basename(os.path.abspath(model_base_dir))
if post_filter:
suffix += "_pf"
logger.info("Running on device {}".format(get_device()))
logger.info("Running on device {}".format(compute_device))
logger.info("Model loaded")
return model, df_state, suffix, epoch


def df_features(audio: Tensor, df: DF, nb_df: int, device=None) -> Tuple[Tensor, Tensor, Tensor]:
def df_features(audio: Tensor, df: DF, nb_df: int, device: Optional[torch.device] = None) -> Tuple[Tensor, Tensor, Tensor]:
spec = df.analysis(audio.numpy()) # [C, Tf] -> [C, Tf, F]
a = get_norm_alpha(False)
erb_fb = df.erb_widths()
Expand All @@ -205,7 +216,12 @@ def df_features(audio: Tensor, df: DF, nb_df: int, device=None) -> Tuple[Tensor,

@torch.no_grad()
def enhance(
model: nn.Module, df_state: DF, audio: Tensor, pad=True, atten_lim_db: Optional[float] = None
model: nn.Module,
df_state: DF,
audio: Tensor,
pad=True,
atten_lim_db: Optional[float] = None,
device: Optional[str] = None,
):
"""Enhance a single audio given a preloaded model and DF state.

Expand All @@ -216,23 +232,30 @@ def enhance(
pad (bool): Pad the audio to compensate for delay due to STFT/ISTFT.
atten_lim_db (float): An optional noise attenuation limit in dB. E.g. an attenuation limit of
12 dB only suppresses 12 dB and keeps the remaining noise in the resulting audio.
device (str): Set the torch compute device.
If `None`, will automatically choose an available backend. (Optional)

Returns:
enhanced audio (Tensor): If `pad` was `False` of shape [C, T'] where T'<T slightly delayed due to STFT.
If `pad` was `True` it has the same shape as the input.
"""
compute_device = get_device(device=device)
model.to(compute_device)
model.eval()

bs = audio.shape[0]
if hasattr(model, "reset_h0"):
model.reset_h0(batch_size=bs, device=get_device())
model.reset_h0(batch_size=bs, device=compute_device)
orig_len = audio.shape[-1]
n_fft, hop = 0, 0
if pad:
n_fft, hop = df_state.fft_size(), df_state.hop_size()
# Pad audio to compensate for the delay due to the real-time STFT implementation
audio = F.pad(audio, (0, n_fft))
nb_df = getattr(model, "nb_df", getattr(model, "df_bins", ModelParams().nb_df))
spec, erb_feat, spec_feat = df_features(audio, df_state, nb_df, device=get_device())
spec, erb_feat, spec_feat = df_features(
audio, df_state, nb_df, device=compute_device
)
enhanced = model(spec.clone(), erb_feat, spec_feat)[0].cpu()
enhanced = as_complex(enhanced.squeeze(1))
if atten_lim_db is not None and abs(atten_lim_db) > 0:
Expand Down Expand Up @@ -375,6 +398,11 @@ def run():
help="Don't add the model suffix to the enhanced audio files",
)
parser.add_argument("--no-df-stage", action="store_true")
parser.add_argument(
"--device",
type=str,
help="Set the torch compute device",
)
args = parser.parse_args()
main(args)

Expand Down
16 changes: 11 additions & 5 deletions DeepFilterNet/df/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,19 @@
from df.model import ModelParams


def get_device():
s = config("DEVICE", default="", section="train")
if s == "":
def get_device(device: Optional[str] = None):
s = device or config("DEVICE", default="", section="train")
if not s:
if torch.cuda.is_available():
DEVICE = torch.device("cuda:0")
else:
DEVICE = torch.device("cuda")
elif torch.mps.is_available():
DEVICE = torch.device("mps")
elif torch.xpu.is_available():
DEVICE = torch.device("xpu")
elif torch.cpu.is_available():
DEVICE = torch.device("cpu")
else:
raise RuntimeError("No compute devices found")
else:
DEVICE = torch.device(s)
return DEVICE
Expand Down