diff --git a/README.md b/README.md index 607cbbe..65c58a4 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ The weights and the model are exactly the same as in [the official Tensorflow im Install from [pip](https://pypi.org/project/pytorch-fid/): ``` -pip install pytorch-fid +pip install git+https://github.com/XavierJiezou/pytorch-fid.git ``` Requirements: diff --git a/src/pytorch_fid/fid_score.py b/src/pytorch_fid/fid_score.py index 9c8acb2..3736b01 100755 --- a/src/pytorch_fid/fid_score.py +++ b/src/pytorch_fid/fid_score.py @@ -33,6 +33,7 @@ """ import os +import glob import pathlib from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser @@ -268,9 +269,7 @@ def compute_statistics_of_path(path, model, batch_size, dims, device, num_worker m, s = f["mu"][:], f["sigma"][:] else: path = pathlib.Path(path) - files = sorted( - [file for ext in IMAGE_EXTENSIONS for file in path.glob("*.{}".format(ext))] - ) + files = sorted([file for ext in IMAGE_EXTENSIONS for file in glob.glob(os.path.join(path, '**/*.{}'.format(ext)), recursive=True)]) # 支持递归搜索子目录 m, s = calculate_activation_statistics( files, model, batch_size, dims, device, num_workers )