Skip to content

Commit 44c960f

Browse files
authored
Update benchmark.py (#123)
1 parent a74101f commit 44c960f

File tree

6 files changed

+68
-31
lines changed

6 files changed

+68
-31
lines changed

.github/workflows/testing.yml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,10 @@ jobs:
6868
run: poetry lock --no-cache
6969

7070
- name: Install project dependencies
71-
run: poetry install --no-root
71+
run: poetry install --with dev
7272

7373
- name: Run DLC Live Tests
7474
run: poetry run dlc-live-test --nodisplay
75+
76+
- name: Run Functional Benchmark Test
77+
run: poetry run pytest tests/test_benchmark_script.py

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ benchmarking/results*
88
**DS_Store*
99
*vscode*
1010

11+
**/__MACOSX/
12+
1113
# Byte-compiled / optimized / DLL files
1214
__pycache__/
1315
*.py[cod]

dlclive/benchmark.py

Lines changed: 25 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -37,42 +37,38 @@
3737

3838
def download_benchmarking_data(
3939
target_dir=".",
40-
url="http://deeplabcut.rowland.harvard.edu/datasets/dlclivebenchmark.tar.gz",
40+
url="https://huggingface.co/datasets/mwmathis/DLCspeed_benchmarking/resolve/main/Data-DLC-live-benchmark.zip",
4141
):
4242
"""
43-
Downloads a DeepLabCut-Live benchmarking Data (videos & DLC models).
43+
Downloads and extracts DeepLabCut-Live benchmarking data (videos & DLC models).
4444
"""
45+
import os
4546
import urllib.request
46-
import tarfile
4747
from tqdm import tqdm
48+
import zipfile
4849

49-
def show_progress(count, block_size, total_size):
50-
pbar.update(block_size)
51-
52-
def tarfilenamecutting(tarf):
53-
"""' auxfun to extract folder path
54-
ie. /xyz-trainsetxyshufflez/
55-
"""
56-
for memberid, member in enumerate(tarf.getmembers()):
57-
if memberid == 0:
58-
parent = str(member.path)
59-
l = len(parent) + 1
60-
if member.path.startswith(parent):
61-
member.path = member.path[l:]
62-
yield member
63-
64-
response = urllib.request.urlopen(url)
65-
print(
66-
"Downloading the benchmarking data from the DeepLabCut server @Harvard -> Go Crimson!!! {}....".format(
67-
url
68-
)
69-
)
70-
total_size = int(response.getheader("Content-Length"))
71-
pbar = tqdm(unit="B", total=total_size, position=0)
72-
filename, _ = urllib.request.urlretrieve(url, reporthook=show_progress)
73-
with tarfile.open(filename, mode="r:gz") as tar:
74-
tar.extractall(target_dir, members=tarfilenamecutting(tar))
50+
# Avoid nested folder issue
51+
if os.path.basename(os.path.normpath(target_dir)) == "Data-DLC-live-benchmark":
52+
target_dir = os.path.dirname(os.path.normpath(target_dir))
53+
os.makedirs(target_dir, exist_ok=True) # Ensure target directory exists
54+
55+
zip_path = os.path.join(target_dir, "Data-DLC-live-benchmark.zip")
56+
57+
if os.path.exists(zip_path):
58+
print(f"{zip_path} already exists. Skipping download.")
59+
else:
60+
def show_progress(count, block_size, total_size):
61+
pbar.update(block_size)
62+
63+
print(f"Downloading the benchmarking data from {url} ...")
64+
pbar = tqdm(unit="B", total=0, position=0, desc="Downloading")
65+
66+
filename, _ = urllib.request.urlretrieve(url, filename=zip_path, reporthook=show_progress)
67+
pbar.close()
7568

69+
print(f"Extracting {zip_path} to {target_dir} ...")
70+
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
71+
zip_ref.extractall(target_dir)
7672

7773
def get_system_info() -> dict:
7874
""" Return summary info for system running benchmark

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ torch = ">=1.10,<3.0"
3636
dlclibrary = ">=0.0.6"
3737
pandas = "^1.3"
3838
tables = "^3.6"
39+
pytest = "^8.0"
3940

4041
# OS-specific TensorFlow packages
4142
tensorflow = [
@@ -44,7 +45,6 @@ tensorflow = [
4445
]
4546
tensorflow-macos = { version = ">=2.7.0,<2.12", markers = "sys_platform == 'darwin'" }
4647

47-
4848
[tool.poetry.group.dev.dependencies]
4949

5050
[build-system]

pytest.ini

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[pytest]
2+
markers =
3+
functional: functional tests

tests/test_benchmark_script.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import os
2+
import glob
3+
import pathlib
4+
import pytest
5+
from dlclive import benchmark_videos, download_benchmarking_data
6+
7+
@pytest.mark.functional
8+
def test_benchmark_script_runs(tmp_path):
9+
datafolder = tmp_path / "Data-DLC-live-benchmark"
10+
download_benchmarking_data(str(datafolder))
11+
12+
dog_models = glob.glob(str(datafolder / "dog" / "*[!avi]"))
13+
dog_video = glob.glob(str(datafolder / "dog" / "*.avi"))[0]
14+
mouse_models = glob.glob(str(datafolder / "mouse_lick" / "*[!avi]"))
15+
mouse_video = glob.glob(str(datafolder / "mouse_lick" / "*.avi"))[0]
16+
17+
out_dir = tmp_path / "results"
18+
out_dir.mkdir(exist_ok=True)
19+
20+
pixels = [100, 400] #[2500, 10000]
21+
n_frames = 5
22+
23+
for m in dog_models:
24+
print(f"Running dog model: {m}")
25+
result = benchmark_videos(m, dog_video, output=str(out_dir), n_frames=n_frames, pixels=pixels)
26+
print("Dog model result:", result)
27+
28+
for m in mouse_models:
29+
print(f"Running mouse model: {m}")
30+
result = benchmark_videos(m, mouse_video, output=str(out_dir), n_frames=n_frames, pixels=pixels)
31+
print("Mouse model result:", result)
32+
33+
assert any(out_dir.iterdir())

0 commit comments

Comments
 (0)