
Commit 1676c57

feat: add llava_onevision1_5 (#825)
* feat: add llava_onevision1_5
* unset interleave_visuals
* fix: useless class, error link
* fix: change repo
* improve: bot advice
* improve: bot advice.2
* fix format
* Remove invalid claude.yml workflow file
* Save current work progress
* fix link
* fix link
1 parent 7ee9506 commit 1676c57

4 files changed (+342, -2 lines changed)


.github/workflows/claude.yml

Lines changed: 0 additions & 2 deletions
This file was deleted.
Lines changed: 9 additions & 0 deletions
@@ -0,0 +1,9 @@
export HF_HOME="~/.cache/huggingface"

# pip install git+https://github.com/EvolvingLMMs-Lab/lmms-eval.git

accelerate launch --num_processes=8 --main_process_port 12399 -m lmms_eval \
    --model=llava_onevision1_5 \
    --model_args=pretrained=lmms-lab/LLaVA-OneVision-1.5-8B-Instruct,attn_implementation=flash_attention_2,max_pixels=3240000 \
    --tasks=mmerealworld,mmerealworld_cn,chartqa,docvqa_val,infovqa_val,mmstar,ocrbench \
    --batch_size=1
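For quick local testing, the same configuration can also be exercised without the accelerate CLI by constructing the wrapper class directly. This is a minimal sketch, not part of the commit; the module path lmms_eval.models.llava_onevision1_5 is assumed from the registry key added below, since the new file's name is not shown on this page.

# Minimal sketch (not part of the commit): instantiate the new wrapper directly,
# mirroring the --model_args string above. The import path is assumed from the
# registry key "llava_onevision1_5".
from lmms_eval.models.llava_onevision1_5 import Llava_OneVision1_5

model = Llava_OneVision1_5(
    pretrained="lmms-lab/LLaVA-OneVision-1.5-8B-Instruct",
    attn_implementation="flash_attention_2",  # requires flash-attn; omit to use the default
    max_pixels=3240000,
    batch_size=1,
)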

lmms_eval/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -37,6 +37,7 @@
     "llava": "Llava",
     "llava_hf": "LlavaHf",
     "llava_onevision": "Llava_OneVision",
+    "llava_onevision1_5": "Llava_OneVision1_5",
     "llava_onevision_moviechat": "Llava_OneVision_MovieChat",
     "llava_sglang": "LlavaSglang",
     "llava_vid": "LlavaVid",
Lines changed: 332 additions & 0 deletions
@@ -0,0 +1,332 @@
import base64
import re
from io import BytesIO
from typing import List, Optional, Tuple, Union

import decord
import numpy as np
import torch
from accelerate import Accelerator, DistributedType
from loguru import logger as eval_logger
from PIL import Image
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer

from lmms_eval import utils
from lmms_eval.api.instance import Instance
from lmms_eval.api.model import lmms
from lmms_eval.api.registry import register_model

try:
    from qwen_vl_utils import process_vision_info
except ImportError:
    eval_logger.warning("Failed to import qwen_vl_utils; Please install it via `pip install qwen-vl-utils`")


@register_model("llava_onevision1_5")
class Llava_OneVision1_5(lmms):
    """
    Llava_OneVision1_5 Model
    "https://huggingface.co/lmms-lab/LLaVA-OneVision-1.5-8B-Instruct"
    """

    def __init__(
        self,
        pretrained: str = "lmms-lab/LLaVA-OneVision-1.5-8B-Instruct",
        device: Optional[str] = "cuda",
        device_map: Optional[str] = "auto",
        batch_size: Optional[Union[int, str]] = 1,
        use_cache=True,
        attn_implementation: Optional[str] = None,
        min_pixels: int = 256 * 28 * 28,
        max_pixels: int = 1605632,
        max_num_frames: int = 32,
        use_custom_video_loader: Optional[bool] = False,
        fps: Optional[float] = None,  # Only applicable if use_custom_video_loader is True
        max_image_size: Optional[int] = None,  # Only applicable if use_custom_video_loader is True
        system_prompt: Optional[str] = "You are a helpful assistant.",
        interleave_visuals: Optional[bool] = False,
        reasoning_prompt: Optional[str] = None,
        max_length: int = 2048,
        **kwargs,
    ) -> None:
        super().__init__()
        if kwargs:
            eval_logger.warning(f"Ignoring unexpected kwargs: {list(kwargs.keys())}")

        # Validate attention implementation
        valid_attn_implementations = [None, "flash_attention_2", "sdpa", "eager"]
        if attn_implementation not in valid_attn_implementations:
            raise ValueError(f"attn_implementation must be one of {valid_attn_implementations}, got {attn_implementation}")

        self.use_custom_video_loader = use_custom_video_loader
        self.fps = fps
        # if self.fps and not self.use_custom_video_loader:
        #     raise ValueError("FPS is only applicable if use_custom_video_loader is True")
        self.max_image_size = max_image_size
        if self.max_image_size and not self.use_custom_video_loader:
            raise ValueError("max_image_size is only applicable if use_custom_video_loader is True")

        accelerator = Accelerator()
        if accelerator.num_processes > 1:
            self._device = torch.device(f"cuda:{accelerator.local_process_index}")
            self.device_map = f"cuda:{accelerator.local_process_index}"
        else:
            self._device = torch.device(device)
            self.device_map = device_map if device_map else device

        # Prepare model loading arguments
        model_kwargs = {"torch_dtype": "auto", "device_map": self.device_map, "trust_remote_code": True}

        # Add attention implementation if specified
        if attn_implementation is not None:
            model_kwargs["attn_implementation"] = attn_implementation

        self._model = AutoModelForCausalLM.from_pretrained(pretrained, **model_kwargs).eval()
        self.max_pixels = max_pixels
        self.min_pixels = min_pixels
        self.max_num_frames = max_num_frames

        if reasoning_prompt:
            self.reasoning_prompt = reasoning_prompt.replace("\\n", "\n")
        else:
            self.reasoning_prompt = None
        self.processor = AutoProcessor.from_pretrained(pretrained, max_pixels=max_pixels, min_pixels=min_pixels, trust_remote_code=True)
        self._tokenizer = AutoTokenizer.from_pretrained(pretrained, trust_remote_code=True)
        self.system_prompt = system_prompt
        self.interleave_visuals = interleave_visuals

        self._config = self.model.config
        self._max_length = int(max_length)
        self.batch_size_per_gpu = int(batch_size)
        self.use_cache = use_cache

        if accelerator.num_processes > 1:
            assert accelerator.distributed_type in [
                DistributedType.FSDP,
                DistributedType.MULTI_GPU,
            ], "Unsupported distributed type provided. Only DDP and FSDP are supported."
            if accelerator.distributed_type == DistributedType.FSDP:
                self._model = accelerator.prepare(self.model)
            else:
                self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
            self.accelerator = accelerator
            if self.accelerator.is_local_main_process:
                eval_logger.info(f"Using {accelerator.num_processes} devices with data parallelism")
            self._rank = self.accelerator.local_process_index
            self._world_size = self.accelerator.num_processes
        else:
            self._rank = 0
            self._world_size = 1

    @property
    def config(self):
        # return the associated transformers.AutoConfig for the given pretrained model.
        return self._config

    @property
    def tokenizer(self):
        return self._tokenizer

    @property
    def model(self):
        # returns the model, unwrapping it if using Accelerate
        if hasattr(self, "accelerator"):
            return self.accelerator.unwrap_model(self._model)
        else:
            return self._model

    @property
    def eot_token_id(self):
        return self.tokenizer.eos_token_id

    @property
    def max_length(self):
        return self._max_length

    @property
    def batch_size(self):
        return self.batch_size_per_gpu

    @property
    def device(self):
        return self._device

    @property
    def rank(self):
        return self._rank

    @property
    def world_size(self):
        return self._world_size

    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        raise NotImplementedError("Loglikelihood is not implemented for Llava_OneVision1_5")

    def flatten(self, input):
        new_list = []
        for i in input:
            for j in i:
                new_list.append(j)
        return new_list

    def generate_until(self, requests: List[Instance]) -> List[str]:
        res = []

        def _collate(x):
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            toks = self.tokenizer.encode(x[0])
            return -len(toks), x[0]

        pbar = tqdm(total=len(requests), disable=(self.rank != 0), desc="Model Responding")
        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        re_ords = utils.Collator([reg.args for reg in requests], _collate, grouping=True)
        chunks = re_ords.get_batched(n=self.batch_size, batch_fn=None)
        for chunk in chunks:
            contexts, all_gen_kwargs, doc_to_visual, doc_id, task, split = zip(*chunk)
            task = task[0]
            split = split[0]
            visual_list = [doc_to_visual[0](self.task_dict[task][split][ids]) for ids in doc_id]
            gen_kwargs = all_gen_kwargs[0]

            # Set default until or update values from gen_kwargs if present
            until = gen_kwargs.get("until", [self.tokenizer.decode(self.eot_token_id)])

            if isinstance(until, str):
                until = [until]
            elif not isinstance(until, list):
                raise ValueError(f"Expected `gen_kwargs['until']` to be of type Union[str, list], but got {type(until)}")

            # Avoid using '\n\n' as a stop sequence, to prevent truncation which can lead to incorrect results
            until = [item for item in until if item != "\n\n"]

            if isinstance(contexts, tuple):
                contexts = list(contexts)

            for i in range(len(contexts)):
                if "<image>" in contexts[i]:
                    contexts[i] = contexts[i].replace("<image>", "")

            batched_messages = []
            for i, context in enumerate(contexts):
                if "<image>" in context:
                    context = context.replace("<image>", "")

                message = [{"role": "system", "content": self.system_prompt}]
                if self.reasoning_prompt:
                    context = context.strip() + self.reasoning_prompt
                    contexts[i] = context

                processed_visuals = []
                for visual in visual_list[i]:
                    if isinstance(visual, str) and visual.endswith((".mp4", ".avi", ".mov")):  # Video file
                        vr = decord.VideoReader(visual)
                        first_frame = vr[0].asnumpy()
                        height, width = first_frame.shape[:2]
                        # max_pixels = height * width
                        processed_visuals.append({"type": "video", "video": visual, "max_pixels": self.max_pixels, "min_pixels": self.min_pixels})
                    elif isinstance(visual, Image.Image):
                        processed_visuals.append({"type": "image", "image": visual.convert("RGB")})

                if self.interleave_visuals is False:
                    message.append(
                        {
                            "role": "user",
                            "content": processed_visuals + [{"type": "text", "text": context}],
                        }
                    )
                else:  # currently support find <image x> in the context
                    image_placeholders = re.findall(r"<image \d+>", context)
                    content_parts = []
                    text_parts = re.split(r"<image \d+>", context)
                    if text_parts[0]:
                        content_parts.append({"type": "text", "text": text_parts[0]})

                    for i, placeholder in enumerate(image_placeholders):
                        img_idx = int(re.search(r"<image (\d+)>", placeholder).group(1)) - 1
                        image_idx = min(img_idx, len(processed_visuals) - 1) if processed_visuals else 0
                        if processed_visuals and image_idx < len(processed_visuals):
                            content_parts.append(processed_visuals[image_idx])
                        if i + 1 < len(text_parts) and text_parts[i + 1]:
                            content_parts.append({"type": "text", "text": text_parts[i + 1]})

                    message.append(
                        {
                            "role": "user",
                            "content": content_parts,
                        }
                    )

                batched_messages.append(message)

            texts = [self.processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True) for msg in batched_messages]
            image_inputs, video_inputs = process_vision_info(batched_messages)
            if video_inputs is not None:
                total_frames = video_inputs[0].shape[0]
                indices = np.linspace(0, total_frames - 1, self.max_num_frames, dtype=int)
                # Append the last frame index if not already included
                if total_frames - 1 not in indices:
                    indices = np.append(indices, total_frames - 1)
                video_inputs[0] = video_inputs[0][indices]
            inputs = self.processor(text=texts, images=image_inputs, videos=video_inputs, padding=True, return_tensors="pt")

            if self.device_map == "auto":
                inputs = inputs.to("cuda")
            else:
                inputs = inputs.to(self.device)

            # Set default generation kwargs
            default_gen_kwargs = {
                "max_new_tokens": 128,
                "temperature": 0.0,  # Set to 0 for greedy default
                "top_p": None,
                "num_beams": 1,
            }
            # Update with provided kwargs
            current_gen_kwargs = {**default_gen_kwargs, **gen_kwargs}
            pad_token_id = self.tokenizer.pad_token_id or self.tokenizer.eos_token_id
            do_sample = bool(current_gen_kwargs.get("temperature", 0) and current_gen_kwargs["temperature"] > 0)
            gen_args = {
                **inputs,
                "eos_token_id": self.tokenizer.eos_token_id,
                "pad_token_id": pad_token_id,
                "num_beams": current_gen_kwargs["num_beams"],
                "max_new_tokens": current_gen_kwargs["max_new_tokens"],
                "use_cache": self.use_cache,
            }
            if do_sample:
                gen_args.update(
                    do_sample=True,
                    temperature=float(current_gen_kwargs.get("temperature", 1.0)),
                    top_p=float(current_gen_kwargs.get("top_p", 1.0)),
                )
            with torch.inference_mode():
                cont = self.model.generate(**gen_args)

            generated_ids_trimmed = [out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, cont)]
            answers = self.processor.batch_decode(generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False)
            for i, ans in enumerate(answers):
                for term in until:
                    if len(term) > 0:
                        ans = ans.split(term)[0]
                answers[i] = ans

            for ans, context in zip(answers, contexts):
                res.append(ans)
                self.cache_hook.add_partial("generate_until", (context, gen_kwargs), ans)
                pbar.update(1)
        # reorder this group of results back to original unsorted form
        res = re_ords.get_original(res)

        pbar.close()
        return res

    def generate_until_multi_round(self, requests) -> List[str]:
        raise NotImplementedError("TODO: Implement multi-round generation")
