diff --git a/gfpgan/utils.py b/gfpgan/utils.py index 74ee5a83..09adf099 100644 --- a/gfpgan/utils.py +++ b/gfpgan/utils.py @@ -29,7 +29,7 @@ class GFPGANer(): bg_upsampler (nn.Module): The upsampler for the background. Default: None. """ - def __init__(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None, device=None): + def __init__(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None, device=None, model_dir=None): self.upscale = upscale self.bg_upsampler = bg_upsampler @@ -84,11 +84,11 @@ def __init__(self, model_path, upscale=2, arch='clean', channel_multiplier=2, bg save_ext='png', use_parse=True, device=self.device, - model_rootpath='gfpgan/weights') + model_rootpath=model_dir or 'gfpgan/weights') if model_path.startswith('https://'): model_path = load_file_from_url( - url=model_path, model_dir=os.path.join(ROOT_DIR, 'gfpgan/weights'), progress=True, file_name=None) + url=model_path, model_dir=model_dir or os.path.join(ROOT_DIR, 'gfpgan/weights'), progress=True, file_name=None) loadnet = torch.load(model_path) if 'params_ema' in loadnet: keyname = 'params_ema'