Skip to content

Commit 5a03270

Browse files
committed
add tqdm download progress bar
1 parent 4b7664f commit 5a03270

File tree

2 files changed

+20
-5
lines changed

2 files changed

+20
-5
lines changed

requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ langchain==0.0.101
22
torch==1.13.1
33
torchvision==0.14.1
44
wget==3.2
5+
tdqm
56
accelerate
67
addict
78
albumentations

visual_chatgpt.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import numpy as np
4545
import matplotlib.pyplot as plt
4646
import wget
47+
from tqdm import tqdm
4748

4849
VISUAL_CHATGPT_PREFIX = """Visual ChatGPT is designed to be able to assist with a wide range of text and visual related tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. Visual ChatGPT is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand.
4950
@@ -138,6 +139,18 @@
138139
os.makedirs('image', exist_ok=True)
139140

140141

142+
class ProgressBar:
143+
144+
def __init__(self, url):
145+
self.progress_bar = None
146+
print(f"downloading from {url}")
147+
148+
def __call__(self, current_bytes, total_bytes, width):
149+
if self.progress_bar is None:
150+
self.progress_bar = tqdm(total=total_bytes, unit='B', unit_scale=True, unit_divisor=1024)#tqdm(total=total_mb, desc="MB")
151+
self.progress_bar.update(current_bytes)
152+
153+
141154
def seed_everything(seed):
142155
random.seed(seed)
143156
np.random.seed(seed)
@@ -817,7 +830,7 @@ def __init__(self, device):
817830
def download_parameters(self):
818831
url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
819832
if not os.path.exists(self.model_checkpoint_path):
820-
wget.download(url,out=self.model_checkpoint_path)
833+
wget.download(url, out=self.model_checkpoint_path, bar=ProgressBar(url))
821834

822835

823836
def show_mask(self, mask: np.ndarray,image: np.ndarray,
@@ -1038,12 +1051,13 @@ def __init__(self, device):
10381051
def download_parameters(self):
10391052
url = "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth"
10401053
if not os.path.exists(self.model_checkpoint_path):
1041-
wget.download(url,out=self.model_checkpoint_path)
1054+
wget.download(url, out=self.model_checkpoint_path, bar=ProgressBar(url))
10421055
config_url = "https://raw.githubusercontent.com/IDEA-Research/GroundingDINO/main/groundingdino/config/GroundingDINO_SwinT_OGC.py"
10431056
if not os.path.exists(self.model_config_path):
1044-
wget.download(config_url,out=self.model_config_path)
1045-
def load_image(self,image_path):
1046-
# load image
1057+
wget.download(config_url, out=self.model_config_path, bar=ProgressBar(url))
1058+
1059+
def load_image(self, image_path):
1060+
# load image
10471061
image_pil = Image.open(image_path).convert("RGB") # load image
10481062

10491063
transform = T.Compose(

0 commit comments

Comments
 (0)