Skip to content

Commit f8fc8ba

Browse files
authored
Update benchmark.py
1 parent 6cae578 commit f8fc8ba

File tree

1 file changed

+18
-21
lines changed

1 file changed

+18
-21
lines changed

dlclive/benchmark.py

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -42,37 +42,34 @@ def download_benchmarking_data(
4242
"""
4343
Downloads a DeepLabCut-Live benchmarking Data (videos & DLC models).
4444
"""
45+
import os
4546
import urllib.request
46-
import tarfile
4747
from tqdm import tqdm
48+
import zipfile
4849

4950
def show_progress(count, block_size, total_size):
5051
pbar.update(block_size)
5152

52-
def tarfilenamecutting(tarf):
53-
"""' auxfun to extract folder path
54-
ie. /xyz-trainsetxyshufflez/
55-
"""
56-
for memberid, member in enumerate(tarf.getmembers()):
57-
if memberid == 0:
58-
parent = str(member.path)
59-
l = len(parent) + 1
60-
if member.path.startswith(parent):
61-
member.path = member.path[l:]
62-
yield member
63-
6453
response = urllib.request.urlopen(url)
54+
total_size = int(response.getheader("Content-Length"))
55+
pbar = tqdm(unit="B", total=total_size, position=0, desc="Downloading")
56+
filename, _ = urllib.request.urlretrieve(url, filename=zip_path, reporthook=show_progress)
57+
pbar.close()
58+
59+
class DownloadProgressBar(tqdm):
60+
def update_to(self, b=1, bsize=1, tsize=None):
61+
if tsize is not None:
62+
self.total = tsize
63+
self.update(b * bsize - self.n)
64+
65+
zip_path = os.path.join(target_dir, "Data-DLC-live-benchmark.zip")
6566
print(
66-
"Downloading the benchmarking data from the DeepLabCut server @Harvard -> Go Crimson!!! {}....".format(
67-
url
68-
)
67+
f"Downloading the benchmarking data from {url} ..."
6968
)
70-
total_size = int(response.getheader("Content-Length"))
71-
pbar = tqdm(unit="B", total=total_size, position=0)
72-
filename, _ = urllib.request.urlretrieve(url, reporthook=show_progress)
73-
with tarfile.open(filename, mode="r:gz") as tar:
74-
tar.extractall(target_dir, members=tarfilenamecutting(tar))
7569

70+
print(f"Extracting {zip_path} to {target_dir} ...")
71+
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
72+
zip_ref.extractall(target_dir)
7673

7774
def get_system_info() -> dict:
7875
""" Return summary info for system running benchmark

0 commit comments

Comments
 (0)