Skip to content

Commit 2da0770

Browse files
committed
Cache test datasets
Signed-off-by: Beat Buesser <[email protected]>
1 parent ec0e3fc commit 2da0770

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

.github/workflows/ci-style-checks.yml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,14 +28,17 @@ jobs:
2828
steps:
2929
- name: Checkout Repo
3030
uses: actions/checkout@v4
31+
3132
- name: Setup Python
3233
uses: actions/setup-python@v5
3334
with:
3435
python-version: '3.10'
36+
3537
- name: Pre-install
3638
run: |
3739
sudo apt-get update
3840
sudo apt-get -y -q install ffmpeg libavcodec-extra
41+
3942
- name: Install Dependencies
4043
run: |
4144
python -m pip install --upgrade pip setuptools wheel
@@ -45,17 +48,22 @@ jobs:
4548
pip install tensorflow==2.13.1
4649
pip install keras==2.13.1
4750
pip list
51+
4852
- name: pycodestyle
4953
run: pycodestyle --ignore=C0330,C0415,E203,E231,W503 --max-line-length=120 art
54+
5055
- name: pylint
5156
if: ${{ always() }}
5257
run: pylint --fail-under=9.6 art/
58+
5359
- name: mypy
5460
if: ${{ always() }}
5561
run: mypy art
62+
5663
- name: ruff
5764
if: ${{ always() }}
5865
run: ruff check art/ tests/ examples/
66+
5967
- name: black
6068
if: ${{ always() }}
6169
run: |

art/utils.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1616,14 +1616,15 @@ def get_file(
16161616
if not os.path.exists(path_):
16171617
os.makedirs(path_)
16181618

1619+
target_path = os.path.join(path_, filename)
1620+
16191621
if extract:
1620-
extract_path = os.path.join(path_, filename)
1621-
full_path = extract_path + ".tar.gz"
1622+
full_path = target_path + ".tar.gz"
16221623
else:
1623-
full_path = os.path.join(path_, filename)
1624+
full_path = target_path
16241625

16251626
# Determine if dataset needs downloading
1626-
download = not os.path.exists(full_path)
1627+
download = not os.path.exists(target_path)
16271628

16281629
if download:
16291630
logger.info("Downloading data from %s", url)
@@ -1671,9 +1672,9 @@ def progress_bar(blocks: int = 1, block_size: int = 1, total_size: int | None =
16711672
raise
16721673

16731674
if extract:
1674-
if not os.path.exists(extract_path):
1675+
if not os.path.exists(target_path):
16751676
_extract(full_path, path_)
1676-
return extract_path
1677+
return target_path
16771678

16781679
return full_path
16791680

0 commit comments

Comments
 (0)