Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ langchain==0.0.101
torch==1.13.1
torchvision==0.14.1
wget==3.2
tdqm
accelerate
addict
albumentations
Expand Down
24 changes: 19 additions & 5 deletions visual_chatgpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
import numpy as np
import matplotlib.pyplot as plt
import wget
from tqdm import tqdm

VISUAL_CHATGPT_PREFIX = """Visual ChatGPT is designed to be able to assist with a wide range of text and visual related tasks, from answering simple questions to providing in-depth explanations and discussions on a wide range of topics. Visual ChatGPT is able to generate human-like text based on the input it receives, allowing it to engage in natural-sounding conversations and provide responses that are coherent and relevant to the topic at hand.

Expand Down Expand Up @@ -138,6 +139,18 @@
os.makedirs('image', exist_ok=True)


class ProgressBar:

def __init__(self, url):
self.progress_bar = None
print(f"Downloading checkpoints file from {url}")

def __call__(self, current_bytes, total_bytes, width):
if self.progress_bar is None:
self.progress_bar = tqdm(total=total_bytes, unit='B', unit_scale=True, unit_divisor=1024)#tqdm(total=total_mb, desc="MB")
self.progress_bar.update(current_bytes)


def seed_everything(seed):
random.seed(seed)
np.random.seed(seed)
Expand Down Expand Up @@ -817,7 +830,7 @@ def __init__(self, device):
def download_parameters(self):
url = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth"
if not os.path.exists(self.model_checkpoint_path):
wget.download(url,out=self.model_checkpoint_path)
wget.download(url, out=self.model_checkpoint_path, bar=ProgressBar(url))


def show_mask(self, mask: np.ndarray,image: np.ndarray,
Expand Down Expand Up @@ -1038,12 +1051,13 @@ def __init__(self, device):
def download_parameters(self):
url = "https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth"
if not os.path.exists(self.model_checkpoint_path):
wget.download(url,out=self.model_checkpoint_path)
wget.download(url, out=self.model_checkpoint_path, bar=ProgressBar(url))
config_url = "https://raw.githubusercontent.com/IDEA-Research/GroundingDINO/main/groundingdino/config/GroundingDINO_SwinT_OGC.py"
if not os.path.exists(self.model_config_path):
wget.download(config_url,out=self.model_config_path)
def load_image(self,image_path):
# load image
wget.download(config_url, out=self.model_config_path, bar=ProgressBar(url))

def load_image(self, image_path):
# load image
image_pil = Image.open(image_path).convert("RGB") # load image

transform = T.Compose(
Expand Down