2 changes: 1 addition & 1 deletion README.md
@@ -3,7 +3,7 @@

**King Abdullah University of Science and Technology**

<a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a> <a href='https://arxiv.org/abs/2304.10592'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a> <a href='https://huggingface.co/spaces/Vision-CAIR/minigpt4'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a> <a href='https://huggingface.co/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue'></a> [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1OK4kYsZphwt5DXchKkzMBjYF6jnkqh4R?usp=sharing) [![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://www.youtube.com/watch?v=__tftoxpBAw&feature=youtu.be)
<a href='https://minigpt-4.github.io'><img src='https://img.shields.io/badge/Project-Page-Green'></a> <a href='https://arxiv.org/abs/2304.10592'><img src='https://img.shields.io/badge/Paper-Arxiv-red'></a> <a href='https://huggingface.co/spaces/Vision-CAIR/minigpt4'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue'></a> <a href='https://huggingface.co/Vision-CAIR/MiniGPT-4'><img src='https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Model-blue'></a> [![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1OK4kYsZphwt5DXchKkzMBjYF6jnkqh4R?usp=sharing) [![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://www.youtube.com/watch?v=__tftoxpBAw&feature=youtu.be) <a href="https://replicate.com/daanelson/minigpt-4"><img src="https://replicate.com/daanelson/minigpt-4/badge"></a>


## News
27 changes: 27 additions & 0 deletions cog.yaml
@@ -0,0 +1,27 @@
build:
  gpu: true
  cuda: "11.3"
  system_packages:
    - "libgl1-mesa-glx"
    - "libglib2.0-0"
  python_version: "3.8"
  python_packages:
    - "torch==1.12.1"
    - "torchvision"
    - "transformers==4.28.1"
    - "gradio==3.28.1"
    - "omegaconf==2.1.2"
    - "iopath"
    - "timm==0.6.13"
    - "webdataset==0.2.48"
    - "opencv-python==4.7.0.72"
    - "tensorizer"
    - "decord==0.6.0"
    - "sentencepiece"

  run:
    - "echo 'deb [signed-by=/usr/share/keyrings/cloud.google.gpg] https://packages.cloud.google.com/apt cloud-sdk main' | tee -a /etc/apt/sources.list.d/google-cloud-sdk.list"
    - "curl https://packages.cloud.google.com/apt/doc/apt-key.gpg | apt-key --keyring /usr/share/keyrings/cloud.google.gpg add -"
    - "apt-get update && apt-get install google-cloud-cli"

predict: "predict.py:Predictor"
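
Not part of the diff, but for context: once this image is built and pushed, the resulting Replicate model (the `daanelson/minigpt-4` deployment linked from the README badge) could be called from Python roughly as sketched below. The `replicate` client, the `REPLICATE_API_TOKEN` environment variable, the example image filename, and the version hash (to be copied from the model page) are assumptions on my part; the input keys mirror the parameters declared in `predict.py` below.

```python
# Illustrative sketch only -- not part of this PR.
# Assumes `pip install replicate`, REPLICATE_API_TOKEN in the environment,
# and that the version hash placeholder is replaced with the one shown on
# the Replicate model page.
import replicate

output = replicate.run(
    "daanelson/minigpt-4:<version-hash-from-model-page>",
    input={
        "image": open("example.jpg", "rb"),  # hypothetical local image file
        "prompt": "What is unusual about this picture?",
        "num_beams": 3,
        "temperature": 1.0,
        "top_p": 0.9,
        "repetition_penalty": 1.0,
        "max_new_tokens": 300,
        "max_length": 2000,
    },
)
print(output)  # predict() returns a plain string
```
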
86 changes: 86 additions & 0 deletions predict.py
@@ -0,0 +1,86 @@
from cog import BasePredictor, Input, Path
from minigpt4.common.config import Config
import torch
import argparse
from PIL import Image

from minigpt4.models import MiniGPT4
from minigpt4.common.registry import registry
from minigpt4.conversation.conversation import Chat, CONV_VISION

import os

# set the torch cache directory
os.environ["TORCH_HOME"] = "/src/model_cache"


class Predictor(BasePredictor):
    def setup(self):
        args = argparse.Namespace()
        args.cfg_path = "/src/eval_configs/minigpt4_eval.yaml"
        args.gpu_id = 0
        args.options = []

        config = Config(args)

        model = MiniGPT4.from_config(config.model_cfg).to("cuda")
        vis_processor_cfg = config.datasets_cfg.cc_sbu_align.vis_processor.train
        vis_processor = registry.get_processor_class(
            vis_processor_cfg.name
        ).from_config(vis_processor_cfg)
        self.chat = Chat(model, vis_processor, device="cuda")

    def predict(
        self,
        image: Path = Input(description="Image to discuss"),
        prompt: str = Input(description="Prompt for MiniGPT-4 about the input image"),
        num_beams: int = Input(
            description="Number of beams for beam search decoding",
            default=3,
            ge=1,
            le=10,
        ),
        temperature: float = Input(
            description="Temperature for generating tokens; lower = more predictable results",
            default=1.0,
            ge=0.01,
            le=2.0,
        ),
        top_p: float = Input(
            description="Sample from the top p percent most likely tokens",
            default=0.9,
            ge=0.0,
            le=1.0,
        ),
        repetition_penalty: float = Input(
            description="Penalty for repeated words in generated text; 1 is no penalty, values greater than 1 discourage repetition, less than 1 encourage it.",
            default=1.0,
            ge=0.01,
            le=5,
        ),
        max_new_tokens: int = Input(
            description="Maximum number of new tokens to generate", ge=1, default=3000
        ),
        max_length: int = Input(
            description="Total length of prompt and output in tokens",
            ge=1,
            default=4000,
        ),
    ) -> str:
        img_list = []
        image = Image.open(image).convert("RGB")
        with torch.inference_mode():
            chat_state = CONV_VISION.copy()
            self.chat.upload_img(image, chat_state, img_list)
            self.chat.ask(prompt, chat_state)
            answer = self.chat.answer(
                conv=chat_state,
                img_list=img_list,
                num_beams=num_beams,
                temperature=temperature,
                max_new_tokens=max_new_tokens,
                max_length=max_length,
                top_p=top_p,
                repetition_penalty=repetition_penalty,
            )
        return answer[0]
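
As a sanity check outside the Cog HTTP server, the `Predictor` class above can also be exercised directly. The sketch below is not part of the diff; it assumes a CUDA GPU, the MiniGPT-4 checkpoints referenced by `/src/eval_configs/minigpt4_eval.yaml`, and a local test image named `example.jpg` (a hypothetical filename).

```python
# Minimal local smoke test -- not part of this PR. Assumes a CUDA GPU,
# the checkpoints referenced by /src/eval_configs/minigpt4_eval.yaml,
# and a local image called example.jpg.
from predict import Predictor

predictor = Predictor()
predictor.setup()  # loads MiniGPT-4 and its visual processor onto the GPU

answer = predictor.predict(
    image="example.jpg",
    prompt="Describe this image in detail.",
    num_beams=3,
    temperature=1.0,
    top_p=0.9,
    repetition_penalty=1.0,
    max_new_tokens=300,
    max_length=2000,
)
print(answer)
```

Since `predict()` returns a plain string, the caller needs no extra decoding step.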