|
44 | 44 | import numpy as np |
45 | 45 | import matplotlib.pyplot as plt |
46 | 46 | import wget |
| 47 | +from tqdm import tqdm |
47 | 48 |
|
48 | 49 | VISUAL_CHATGPT_PREFIX = """Visual ChatGPT is designed to be able to assist with a wide range of text and visual related tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. Visual ChatGPT is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand. |
49 | 50 |
|
|
138 | 139 | os.makedirs('image', exist_ok=True) |
139 | 140 |
|
140 | 141 |
|
| 142 | +class ProgressBar: |
| 143 | + |
| 144 | + def __init__(self, url): |
| 145 | + self.progress_bar = None |
| 146 | + print(f"downloading from {url}") |
| 147 | + |
| 148 | + def __call__(self, current_bytes, total_bytes, width): |
| 149 | + if self.progress_bar is None: |
| 150 | + self.progress_bar = tqdm(total=total_bytes, unit='B', unit_scale=True, unit_divisor=1024)#tqdm(total=total_mb, desc="MB") |
| 151 | + self.progress_bar.update(current_bytes) |
| 152 | + |
| 153 | + |
141 | 154 | def seed_everything(seed): |
142 | 155 | random.seed(seed) |
143 | 156 | np.random.seed(seed) |
@@ -817,7 +830,7 @@ def __init__(self, device): |
817 | 830 | def download_parameters(self): |
818 | 831 | url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth" |
819 | 832 | if not os.path.exists(self.model_checkpoint_path): |
820 | | - wget.download(url,out=self.model_checkpoint_path) |
| 833 | + wget.download(url, out=self.model_checkpoint_path, bar=ProgressBar(url)) |
821 | 834 |
|
822 | 835 |
|
823 | 836 | def show_mask(self, mask: np.ndarray,image: np.ndarray, |
@@ -1038,12 +1051,13 @@ def __init__(self, device): |
1038 | 1051 | def download_parameters(self): |
1039 | 1052 | url = "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth" |
1040 | 1053 | if not os.path.exists(self.model_checkpoint_path): |
1041 | | - wget.download(url,out=self.model_checkpoint_path) |
| 1054 | + wget.download(url, out=self.model_checkpoint_path, bar=ProgressBar(url)) |
1042 | 1055 | config_url = "https://raw.githubusercontent.com/IDEA-Research/GroundingDINO/main/groundingdino/config/GroundingDINO_SwinT_OGC.py" |
1043 | 1056 | if not os.path.exists(self.model_config_path): |
1044 | | - wget.download(config_url,out=self.model_config_path) |
1045 | | - def load_image(self,image_path): |
1046 | | - # load image |
| 1057 | + wget.download(config_url, out=self.model_config_path, bar=ProgressBar(url)) |
| 1058 | + |
| 1059 | + def load_image(self, image_path): |
| 1060 | + # load image |
1047 | 1061 | image_pil = Image.open(image_path).convert("RGB") # load image |
1048 | 1062 |
|
1049 | 1063 | transform = T.Compose( |
|
0 commit comments