
Commit bdac3e9: add LLaMA-Adapter V2.1

1 parent e8180d1

File tree

6 files changed: +257, -4 lines


README.md

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ This repo proposes **LLaMA-Adapter (V2)**, a lightweight adaption method for fin
 Try out the web demo 🤗 of LLaMA-Adapter: [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/csuhan/LLaMA-Adapter), [LLaMA-Adapter V2](http://llama-adapter.opengvlab.com/) and [ImageBind-LLM](http://imagebind-llm.opengvlab.com/).
 
 ## News
+- **[2023.10.11]** We release **LLaMA-Adapter V2.1**, an improved version of LLaMA-Adapter V2 with stronger multi-modal reasoning performance. Check [llama_adapter_v2_multimodal7b](llama_adapter_v2_multimodal7b) for details.
 - **[2023.08.28]** We release quantized LLMs with [OmniQuant](https://github.com/OpenGVLab/OmniQuant), an efficient, accurate, and omnibearing (even extremely low-bit) quantization algorithm. A multimodal version is coming soon. 🔥🔥🔥
 - **[2023.07.24]** We release **[LLaMA2-Accessory](https://github.com/Alpha-VLLM/LLaMA2-Accessory)**, an open-source toolkit for **pre-training**, **fine-tuning**, and **deployment** of **Large Language Models (LLMs)** and **multimodal LLMs**. Please check [Alpha-VLLM/LLaMA2-Accessory](https://github.com/Alpha-VLLM/LLaMA2-Accessory) for more details! 🔥🔥🔥
 - **[2023.07.05]** We release the pretrain/finetune code of [llama_adapter_v2_multimodal7b](https://github.com/OpenGVLab/LLaMA-Adapter/tree/main/llama_adapter_v2_multimodal7b).

llama_adapter_v2_multimodal7b/README.md

Lines changed: 5 additions & 2 deletions
@@ -1,6 +1,7 @@
 # LLaMA-Adapter-V2 Multi-modal
 
 ## News
+* [Oct 11, 2023] Release LLaMA-Adapter V2.1 and evaluation on MME.
 * [July 5, 2023] Release pre-training and fine-tuning code.
 * [May 26, 2023] Initial release.
 
@@ -37,8 +38,8 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 
 llama_dir = "/path/to/LLaMA/"
 
-# choose from BIAS-7B, LORA-BIAS-7B
-model, preprocess = llama.load("BIAS-7B", llama_dir, device)
+# choose from BIAS-7B, LORA-BIAS-7B, LORA-BIAS-7B-v21
+model, preprocess = llama.load("BIAS-7B", llama_dir, llama_type="7B", device=device)
 model.eval()
 
 prompt = llama.format_prompt("Please introduce this painting.")
@@ -55,6 +56,8 @@ The output will look like the following:
 The painting features a cute white lama, or llama, standing on a wooden floor. The llama is holding a variety of tools and accessories, such as a paintbrush, a pencil, a ruler, a pair of scissors, and a paint can. The llama is dressed in a suit, which adds a touch of sophistication to the scene. The painting is a creative and whimsical representation of a person or animal holding various tools and accessories, making it an interesting and unique piece of art.
 ```
 
+## Evaluation
+Check [eval.md](./docs/eval.md) for details.
 
 ## Online demo

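For context, here is a minimal end-to-end sketch of how the updated `load` call fits into the quick-start flow above. The image path is an illustrative placeholder, and the `generate` call follows the pattern used in `util/evaluate_mme.py` later in this commit; treat it as a sketch, not the canonical demo.

```
import cv2
import llama
import torch
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
llama_dir = "/path/to/LLaMA/"

# llama_type is now passed explicitly instead of being parsed from the checkpoint name.
model, preprocess = llama.load("LORA-BIAS-7B-v21", llama_dir, llama_type="7B", device=device)
model.eval()

prompt = llama.format_prompt("Please introduce this painting.")
img = Image.fromarray(cv2.imread("/path/to/painting.jpg"))  # illustrative image path
img = preprocess(img).unsqueeze(0).to(device)

with torch.no_grad():
    result = model.generate(img, [prompt])[0]
print(result)
```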
llama_adapter_v2_multimodal7b/docs/eval.md

Lines changed: 60 additions & 0 deletions
@@ -0,0 +1,60 @@
1+
# Evaluation on MME Benchmark
2+
3+
[MME](https://github.com/BradyFU/Awesome-Multimodal-Large-Language-Models/tree/Evaluation) is a comprehensive evaluation benchmark for multimodal large language models. It measures both perception and cognition abilities on a total of 14 subtasks, including existence, count, position, color, poster, celebrity, scene, landmark, artwork, OCR, commonsense reasoning, numerical calculation, text translation, and code reasoning.
4+
5+
## Setup & Evaluation
6+
7+
1. Download MME datasets and `eval_tool` from the [MME repo](https://github.com/bradyfu/awesome-multimodal-large-language-models#our-mllm-works), and put them under `MME_Benchmark_release_version`. Now the folder structure will be:
8+
```
9+
MME_Benchmark_release_version
10+
├── artwork
11+
├── celebrity
12+
├── code_reasoning
13+
├── color
14+
├── commonsense_reasoning
15+
├── count
16+
├── eval_tool
17+
│ ├── calculation.py
18+
│ ├── LaVIN
19+
│ └── Your_Results
20+
├── existence
21+
├── landmark
22+
├── numerical_calculation
23+
├── OCR
24+
├── position
25+
├── posters
26+
├── scene
27+
└── text_translation
28+
```
29+
2. Generate MME results using: `python util/evaluate_mme.py --pretrained_path [MODEL_PATH] --llama_path [LLAMA_DIR] --output_path [RESULT_FILE_PATH]`
30+
3. Evaluate LLaMA-Adapter V2.1 with MME's eval_tool: `python MME_Benchmark_release_version/eval_tool/calculation.py --results_dir [RESULT_FILE_PATH]`
31+
32+
## Results
33+
34+
* **LLaMA-Adapter V2.1**
35+
36+
```
37+
=========== Perception ===========
38+
total score: 1326.0875953396435
39+
40+
existence score: 185.0
41+
count score: 133.33333333333331
42+
position score: 56.666666666666664
43+
color score: 118.33333333333334
44+
posters score: 147.9591836734694
45+
celebrity score: 134.70588235294116
46+
scene score: 156.25
47+
landmark score: 167.8391959798995
48+
artwork score: 123.5
49+
OCR score: 102.5
50+
51+
52+
=========== Cognition ===========
53+
total score: 356.42857142857144
54+
55+
commonsense_reasoning score: 106.42857142857144
56+
numerical_calculation score: 47.5
57+
text_translation score: 112.5
58+
code_reasoning score: 90.0
59+
60+
```

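The result files produced in step 2 above are plain tab-separated text, one file per subtask, in the format written by `util/evaluate_mme.py` (shown later in this commit). A minimal sketch for inspecting one of them; the `results/existence.txt` path is assumed purely for illustration:

```
# Each line: image_path <TAB> question <TAB> ground-truth answer <TAB> model answer
with open("results/existence.txt", encoding="utf-8") as f:
    for line in f:
        image_path, question, gt_answer, answer = line.rstrip("\n").split("\t")
        print(f"{gt_answer:>4} | {answer} | {question}")
```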
llama_adapter_v2_multimodal7b/docs/train.md

Lines changed: 4 additions & 0 deletions
@@ -69,6 +69,10 @@ import os
 from llama.llama_adapter import LLaMA_adapter
 import util.misc as misc
 import util.extract_adapter_from_checkpoint as extract
+from PIL import Image
+import cv2
+import torch
+import llama
 
 device = "cuda" if torch.cuda.is_available() else "cpu"

llama_adapter_v2_multimodal7b/llama/llama_adapter.py

Lines changed: 3 additions & 2 deletions
@@ -279,14 +279,15 @@ def generate(
     "BIAS-7B": "https://github.com/OpenGVLab/LLaMA-Adapter/releases/download/v.2.0.0/7fa55208379faf2dd862565284101b0e4a2a72114d6490a95e432cf9d9b6c813_BIAS-7B.pth",
     "LORA-BIAS-7B": "https://github.com/OpenGVLab/LLaMA-Adapter/releases/download/v.2.0.0/1bcbffc43484332672092e0024a8699a6eb5f558161aebf98a7c6b1db67224d1_LORA-BIAS-7B.pth",
     "CAPTION-7B": "https://github.com/OpenGVLab/LLaMA-Adapter/releases/download/v.2.0.0/5088aeb63a89746b90bcfd5cb819e1c7411b2771b267c6d131ce73e250a8abf0_CAPTION-7B.pth",
+    "LORA-BIAS-7B-v21": "https://github.com/OpenGVLab/LLaMA-Adapter/releases/download/v.2.1.0/427dbc27bf62a3ef7a24ffd3ed2c3162_LORA-BIAS-7B-v21.pth",
     # "LORA16-7B": "",
     # "PARTIAL-7B": ""
 }
 
 def available_models():
     return list(_MODELS.keys())
 
-def load(name, llama_dir, device="cuda" if torch.cuda.is_available() else "cpu", download_root='ckpts', max_seq_len=512,
+def load(name, llama_dir, llama_type="7B", device="cuda" if torch.cuda.is_available() else "cpu", download_root='ckpts', max_seq_len=512,
          phase="finetune"):
     if name in _MODELS:
         model_path = _download(_MODELS[name], download_root)
@@ -296,7 +297,7 @@ def load(name, llama_dir, device="cuda" if torch.cuda.is_available() else "cpu",
         return RuntimeError(f"Model {name} not found; available models = {available_models()}"), None
 
     # BIAS-7B or https://xxx/sha256_BIAS-7B.pth -> 7B
-    llama_type = name.split('.')[0].split('-')[-1]
+    # llama_type = name.split('.')[0].split('-')[-1]
     llama_ckpt_dir = os.path.join(llama_dir, llama_type)
     llama_tokenzier_path = os.path.join(llama_dir, 'tokenizer.model')

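The practical effect of the new `llama_type` argument is that `load()` no longer infers the backbone size from the checkpoint name. A minimal sketch of a call under the new signature, with placeholder paths:

```
import llama

llama_dir = "/path/to/LLaMA/"  # must contain <llama_dir>/7B/ and <llama_dir>/tokenizer.model

# llama_type selects the checkpoint subdirectory, i.e. os.path.join(llama_dir, "7B");
# previously it was derived from the model name (e.g. "BIAS-7B" -> "7B").
model, preprocess = llama.load("LORA-BIAS-7B-v21", llama_dir, llama_type="7B", device="cuda")
```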
llama_adapter_v2_multimodal7b/util/evaluate_mme.py

Lines changed: 184 additions & 0 deletions
@@ -0,0 +1,184 @@
1+
import os
2+
import glob
3+
import argparse
4+
from tqdm import tqdm
5+
import PIL
6+
from PIL import Image
7+
import torch
8+
import torch.distributed as dist
9+
from torch.utils.data import Dataset
10+
import cv2
11+
from llama.llama_adapter import LLaMA_adapter
12+
13+
DATA_DIR = "./MME_Benchmark_release_version"
14+
15+
def get_image(image):
16+
if type(image) is str:
17+
try:
18+
return Image.open(image).convert("RGB")
19+
except Exception as e:
20+
print(f"Fail to read image: {image}")
21+
exit(-1)
22+
elif type(image) is Image.Image:
23+
return image
24+
elif type(image) is PIL.JpegImagePlugin.JpegImageFile:
25+
return image
26+
elif type(image) is PIL.PngImagePlugin.PngImageFile:
27+
return image
28+
elif type(image) is PIL.MpoImagePlugin.MpoImageFile:
29+
return image
30+
else:
31+
raise NotImplementedError(f"Invalid type of Image: {type(image)}")
32+
33+
34+
class MMEDataset(Dataset):
35+
def __init__(
36+
self,
37+
dataset_name
38+
):
39+
self.dataset_name = dataset_name
40+
self.dataset = []
41+
jpg_sets = ["artwork", "celebrity", "color", "count", "existence", "landmark", "OCR", "position", "posters", "scene"]
42+
png_sets = ["code_reasoning", "commonsense_reasoning", "numerical_calculation", "text_translation"]
43+
image_suffix = '.jpg' if dataset_name in jpg_sets else ".png"
44+
45+
assert (dataset_name in jpg_sets) or (dataset_name in png_sets), f"Invalid dataset name for MME benchmark: {dataset_name}"
46+
47+
if os.path.exists(f"{DATA_DIR}/{dataset_name}/images") and os.path.exists(f"{DATA_DIR}/{dataset_name}/questions_answers_YN"):
48+
question_files = os.listdir(f"{DATA_DIR}/{dataset_name}/questions_answers_YN")
49+
for question_file in question_files:
50+
image_file_name = os.path.join(DATA_DIR, dataset_name, "images", question_file.replace('.txt', image_suffix))
51+
with open(os.path.join(DATA_DIR, dataset_name, "questions_answers_YN", question_file), 'r', encoding='utf-8') as f:
52+
for line in f.readlines():
53+
try:
54+
question, gt_answer = line.replace('\n', '').split('\t')
55+
self.dataset.append({
56+
"image_path": image_file_name,
57+
"gt_answers": gt_answer,
58+
"question": question
59+
})
60+
except:
61+
pass
62+
63+
else:
64+
question_files = glob.glob(f"{DATA_DIR}/{dataset_name}/*.txt")
65+
for question_file in question_files:
66+
image_file_name = question_file.replace(".txt", image_suffix)
67+
with open(question_file, 'r', encoding='utf-8') as f:
68+
for line in f.readlines():
69+
try:
70+
question, gt_answer = line.replace('\n', '').split('\t')
71+
self.dataset.append({
72+
"image_path": image_file_name,
73+
"gt_answers": gt_answer,
74+
"question": question
75+
})
76+
except:
77+
pass
78+
79+
def __len__(self):
80+
return len(self.dataset)
81+
82+
def __getitem__(self, idx):
83+
return self.dataset[idx]
84+
85+
86+
def get_args_parser():
87+
parser = argparse.ArgumentParser('Single-turn (conversation) demo', add_help=False)
88+
# Model parameters
89+
parser.add_argument('--llama_path', default='/path/to/llama', type=str,
90+
help='path to LLaMA pretrained checkpoint')
91+
parser.add_argument('--pretrained_path', default='/path/to/pretrained', type=str,
92+
help='directory containing pre-trained checkpoints')
93+
parser.add_argument('--lora', default=16, type=int)
94+
parser.add_argument('--output_path', default='/path/to/output_results', type=str)
95+
return parser
96+
97+
98+
if __name__ == "__main__":
99+
args = get_args_parser().parse_args()
100+
101+
device = "cuda" if torch.cuda.is_available() else "cpu"
102+
103+
llama_dir = args.llama_path
104+
llama_type = '7B'
105+
llama_ckpt_dir = os.path.join(llama_dir, llama_type)
106+
llama_tokenzier_path = os.path.join(llama_dir, 'tokenizer.model')
107+
108+
model_path = args.pretrained_path
109+
# load llama_adapter weights and model_cfg
110+
print(f'Loading LLaMA-Adapter from {model_path}')
111+
ckpt = torch.load(model_path, map_location='cpu')
112+
113+
w_bias = True
114+
w_lora = args.lora > 0
115+
print('Lora:', w_lora)
116+
lora_rank = args.lora
117+
model = LLaMA_adapter(
118+
llama_ckpt_dir, llama_tokenzier_path,
119+
max_seq_len=512, max_batch_size=1,
120+
clip_model='ViT-L/14',
121+
v_embed_dim=768, v_depth=8,
122+
v_num_heads=16, v_mlp_ratio=4.0,
123+
query_len=10, query_layer=31,
124+
w_bias=w_bias,
125+
w_lora=w_lora,
126+
lora_rank=lora_rank,
127+
w_new_gate=w_lora, # for compatibility
128+
phase='finetune')
129+
130+
load_result = model.load_state_dict(ckpt['model'], strict=False)
131+
print(load_result)
132+
133+
model = model.to(device)
134+
model.half()
135+
model.eval()
136+
preprocess = model.clip_transform
137+
138+
prompt_format = (
139+
"Below is an instruction that describes a task. "
140+
"Write a response that appropriately completes the request using a single word or phrase.\n\n"
141+
"### Instruction:\n{instruction}\n\n### Response:"
142+
)
143+
144+
def multi_modal_generate(
145+
img_path: str,
146+
prompt: str,
147+
max_gen_len=30,
148+
temperature: float = 0,
149+
top_p: float = 0.75,
150+
):
151+
img = Image.fromarray(cv2.imread(img_path))
152+
img = preprocess(img).unsqueeze(0).half().to(device)
153+
prompt = prompt_format.format_map({'instruction': prompt})
154+
155+
result = model.generate(img, [prompt],
156+
max_gen_len=max_gen_len,
157+
temperature=temperature,
158+
top_p=top_p)
159+
return result[0]
160+
161+
162+
result = {}
163+
dataset_names = ["artwork", "celebrity", "color", "count", "existence", "OCR", "position", "posters", "scene", "code_reasoning", "commonsense_reasoning", "numerical_calculation", "text_translation", "landmark"] # landmark (03d5e3bfc958be38.jpg)
164+
answer_path = args.output_path
165+
batch_size = 1
166+
167+
print("Starting...")
168+
for dataset_name in dataset_names:
169+
dataset = MMEDataset(dataset_name)
170+
171+
predictions = []
172+
with torch.no_grad():
173+
for data in tqdm(dataset, desc=f"Inferencing {dataset_name}"):
174+
pred = multi_modal_generate(data['image_path'], data['question'])
175+
predictions.append({'image_path': data['image_path'], 'question': data['question'], 'answer': pred, 'gt_answers': data['gt_answers']})
176+
177+
os.makedirs(answer_path, exist_ok=True)
178+
prediction_file = os.path.join(answer_path, f"{dataset_name}.txt")
179+
out_datas = [
180+
f"{data['image_path']}\t{data['question']}\t{data['gt_answers']}\t{data['answer']}"
181+
for data in predictions
182+
]
183+
with open(prediction_file, 'w') as f:
184+
f.write('\n'.join(out_datas))

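As a quick sanity check of the data loading above, `MMEDataset` can be instantiated on its own, assuming `MME_Benchmark_release_version` sits in the working directory (as the `DATA_DIR` constant expects) and that the module is importable as `util.evaluate_mme`:

```
from util.evaluate_mme import MMEDataset  # import path assumed from the run command in eval.md

ds = MMEDataset("existence")
print(len(ds))  # number of (image, question) pairs in this subtask
print(ds[0])    # {'image_path': ..., 'gt_answers': ..., 'question': ...}
```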