Skip to content

Commit d4d2f64

Browse files
committed
Cache test datasets
Signed-off-by: Beat Buesser <[email protected]>
1 parent e620b2d commit d4d2f64

File tree

2 files changed

+15
-6
lines changed

2 files changed

+15
-6
lines changed

.github/workflows/ci-style-checks.yml

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,32 +28,40 @@ jobs:
2828
steps:
2929
- name: Checkout Repo
3030
uses: actions/checkout@v4
31+
3132
- name: Setup Python
3233
uses: actions/setup-python@v5
3334
with:
3435
python-version: '3.10'
36+
3537
- name: Pre-install
3638
run: |
3739
sudo apt-get update
3840
sudo apt-get -y -q install ffmpeg libavcodec-extra
41+
3942
- name: Install Dependencies
4043
run: |
4144
python -m pip install --upgrade pip setuptools wheel
4245
pip install -q -r <(sed '/^tensorflow/d;/^keras/d' requirements_test.txt)
4346
pip install tensorflow==2.18.1
4447
pip install keras==3.10.0
4548
pip list
49+
4650
- name: pycodestyle
4751
run: pycodestyle --ignore=C0330,C0415,E203,E231,W503 --max-line-length=120 art
52+
4853
- name: pylint
4954
if: ${{ always() }}
5055
run: pylint --fail-under=9.6 art/
56+
5157
- name: mypy
5258
if: ${{ always() }}
5359
run: mypy art
60+
5461
- name: ruff
5562
if: ${{ always() }}
5663
run: ruff check art/ tests/ examples/
64+
5765
- name: black
5866
if: ${{ always() }}
5967
run: |

art/utils.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1600,14 +1600,15 @@ def get_file(
16001600
if not os.path.exists(path_):
16011601
os.makedirs(path_)
16021602

1603+
target_path = os.path.join(path_, filename)
1604+
16031605
if extract:
1604-
extract_path = os.path.join(path_, filename)
1605-
full_path = extract_path + ".tar.gz"
1606+
full_path = target_path + ".tar.gz"
16061607
else:
1607-
full_path = os.path.join(path_, filename)
1608+
full_path = target_path
16081609

16091610
# Determine if dataset needs downloading
1610-
download = not os.path.exists(full_path)
1611+
download = not os.path.exists(target_path)
16111612

16121613
if download:
16131614
logger.info("Downloading data from %s", url)
@@ -1655,9 +1656,9 @@ def progress_bar(blocks: int = 1, block_size: int = 1, total_size: int | None =
16551656
raise
16561657

16571658
if extract:
1658-
if not os.path.exists(extract_path):
1659+
if not os.path.exists(target_path):
16591660
_extract(full_path, path_)
1660-
return extract_path
1661+
return target_path
16611662

16621663
return full_path
16631664

0 commit comments

Comments
 (0)