2 changes: 1 addition & 1 deletion llmc/models/__init__.py
@@ -10,7 +10,7 @@
from .internvl2 import InternVL2
from .llama import Llama
from .llava import Llava
from .llava_lht import LlavaLHT
from .llava_hf import LlavaHf
from .minicpm import MiniCPM
from .minicpmv import MiniCPMV
from .mistral import Mistral
275 changes: 158 additions & 117 deletions llmc/models/llava.py
@@ -1,119 +1,117 @@
from typing import List, Optional, Tuple, Union
import types
from datetime import timedelta
from typing import Optional, Union

import torch
from accelerate import Accelerator, DistributedType
from accelerate import Accelerator, DistributedType, InitProcessGroupKwargs
from accelerate.state import AcceleratorState
from lmms_eval.api.model import lmms
from lmms_eval.models.llava_hf import LlavaHf
from lmms_eval.models.llava import Llava as LLaVA
from loguru import logger
from PIL import Image
from transformers import (AutoConfig, AutoProcessor,
LlavaForConditionalGeneration)
from packaging import version
from transformers import AutoConfig, AutoTokenizer

from llmc.utils.registry_factory import MODEL_REGISTRY

from .llama import Llama

try:
from llava.constants import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
DEFAULT_IMAGE_PATCH_TOKEN)
from llava.mm_utils import get_model_name_from_path
from llava.model.builder import load_pretrained_model
from llava.model.language_model.llava_llama import LlavaConfig
except Exception as e:
logger.debug('LLaVA is not installed. Please install LLaVA to use this model.\nError: %s' % e)


@MODEL_REGISTRY
class Llava(Llama):
def __init__(self, config, device_map=None, use_cache=False):
super().__init__(config, device_map, use_cache)

def build_tokenizer(self):
pass

def build_model(self):
self.llava_config = LlavaConfig.from_pretrained(
self.model_path, trust_remote_code=True
)
self.vlm_model_config = AutoConfig.from_pretrained(
self.model_path, trust_remote_code=True
)
if not self.use_cache:
self.vlm_model_config.text_config.use_cache = False
self.llava_config.use_cache = False
self.vlm_model_config.use_cache = False
logger.info(f'self.vlm_model_config : {self.vlm_model_config}')
self.vlm_model = LlavaForConditionalGeneration.from_pretrained(
self.tokenizer, self.vlm_model, image_processor, context_len = load_pretrained_model(
self.model_path,
config=self.vlm_model_config,
None,
get_model_name_from_path(self.model_path),
load_8bit=False,
load_4bit=False,
device='cpu',
torch_dtype=self.torch_dtype,
low_cpu_mem_usage=True,
config=self.llava_config,
)
self.eval_name = 'LlavaHfEval'

# llava forward does not support "cache_position"
ori_forward = self.vlm_model.forward

def safe_forward(*args, **kwargs):
kwargs['use_cache'] = False
kwargs.pop('cache_position', None)
return ori_forward(*args, **kwargs)
self.vlm_model.forward = safe_forward

# llava generate uses "inputs" instead of "input_ids"
ori_generate = self.vlm_model.generate

def safe_generate(*args, **kwargs):
if 'input_ids' in kwargs:
kwargs['inputs'] = kwargs.pop('input_ids')
return ori_generate(*args, **kwargs)
self.vlm_model.generate = safe_generate

# "attention_mask" is passed via kwargs rather than as an explicit keyword argument.
ori_prepare_inputs_for_generation = self.vlm_model.prepare_inputs_for_generation

def safe_prepare_inputs_for_generation(
self, input_ids, past_key_values=None,
inputs_embeds=None, attention_mask=None, **kwargs):
if attention_mask is not None:
kwargs['attention_mask'] = attention_mask
return ori_prepare_inputs_for_generation(
input_ids, past_key_values, inputs_embeds, **kwargs)
self.vlm_model.prepare_inputs_for_generation = types.MethodType(
safe_prepare_inputs_for_generation, self.vlm_model
)

self.eval_name = 'LlavaEval'
self.mm_model = self.vlm_model
logger.info(f'self.vlm_model : {self.vlm_model}')
self.vision_model = self.vlm_model.vision_tower
self.vision_projector = self.vlm_model.multi_modal_projector
self.model = self.vlm_model.language_model
self.vision_model = self.vlm_model.get_vision_tower()
self.vision_projector = self.vlm_model.model.mm_projector
# Llava merges the language model with the vision projector and vision model
self.model = self.vlm_model
self.model_config = self.vlm_model_config.text_config
self.pruning_config = {
'is_video_model': False,
'image_token_start_index': 5,
'image_token_length': self.vlm_model_config.image_seq_length,
'select_layer': self.vlm_model_config.vision_feature_layer,
'select_feature': self.vlm_model_config.vision_feature_select_strategy,
'image_token_index': self.vlm_model_config.image_token_index,
'image_token_index': self.vlm_model_config.image_token_index
}

self.processor = AutoProcessor.from_pretrained(self.model_path)
self.processor = None

def get_extra_rot_module_besides_embed_layers(self):
return [self.vision_projector.linear_2]

def batch_process(
self,
img_qas,
calib_or_eval='eval',
apply_chat_template=True,
return_inputs=True,
): # noqa
assert calib_or_eval == 'calib' or calib_or_eval == 'eval'
assert apply_chat_template
messages = []
images = []
answers = []
for idx in range(len(img_qas)):
img_path = img_qas[idx]['image']
if img_path is not None:
image = Image.open(img_path)
message = [
{
'role': 'user',
'content': [
{'type': 'image'},
{'type': 'text', 'text': img_qas[idx]['question']},
],
}
]
images.append(image)
else:
message = [
{
'role': 'user',
'content': [{'type': 'text', 'text': img_qas[idx]['question']}],
}
]
messages.append(message)
answers.append(img_qas[idx]['answer'])
texts = [
self.processor.apply_chat_template(messages[n], add_generation_prompt=True)
for n in range(len(messages))
]
if calib_or_eval == 'calib' and self.config['calib'].get('add_answer', False):
texts = [texts[n] + ' ' + answers[n] for n in range(len(texts))]
if calib_or_eval == 'calib':
logger.info(f'Calib data is:\n{texts}')
if not return_inputs:
return texts
inputs = self.processor(
text=texts,
images=images if len(images) else None,
padding=True,
return_tensors='pt',
).to(
next(self.vlm_model.parameters()).dtype
) # noqa
return inputs
return [self.vision_projector[2]]

def find_blocks(self):
if self.get_modality() == 'language':
super().find_blocks()
elif self.get_modality() == 'vision':
self.blocks = self.vision_model.vision_model.encoder.layers
self.blocks = self.vision_model.vision_tower.vision_model.encoder.layers
else:
raise Exception(f'Llava does not support {self.get_modality()} modality.')

@@ -166,98 +164,141 @@ def get_subsets_in_block(self, block):
'inspect': block.mlp.fc2,
'has_kwargs': False,
'is_mlp': True,
'do_trans': False,
'do_trans': False
},
]
else:
raise Exception(f'Llava does not support {self.get_modality()} modality.')


if version.parse(torch.__version__) >= version.parse('2.1.2'):
best_fit_attn_implementation = 'sdpa'
else:
best_fit_attn_implementation = 'eager'


@MODEL_REGISTRY
class LlavaHfEval(LlavaHf):
class LlavaEval(LLaVA):
def __init__(
self,
llmc_model,
pretrained: str = 'llava-hf/llava-1.5-7b-hf',
revision: str = 'main',
device: str = 'cuda',
dtype: Optional[Union[str, torch.dtype]] = 'auto',
batch_size: int = 1,
trust_remote_code: Optional[bool] = False,
attn_implementation: Optional[str] = None,
pretrained: str = 'liuhaotian/llava-v1.5-7b',
truncation: Optional[bool] = True,
device: Optional[str] = 'cuda',
batch_size: Optional[Union[int, str]] = 1,
model_name=None,
attn_implementation=best_fit_attn_implementation,
device_map: str = '',
chat_template: Optional[str] = None,
conv_template='vicuna_v1',
use_cache: bool = False,
max_frames_num: Optional[int] = 32,
tie_weights: bool = True,
truncate_context=False, # keep this False for LLaVA-1.6 regardless of truncation
customized_config=None, # ends in json
**kwargs,
) -> None:

lmms.__init__(self)
# Do not use kwargs for now
assert kwargs == {}, f'Unexpected kwargs: {kwargs}'

accelerator = Accelerator()
if accelerator.num_processes > 1 and device_map == '':
accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
self.accelerator = accelerator
if accelerator.num_processes > 1:
self._device = torch.device(f'cuda:{accelerator.local_process_index}')
self.device_map = f'cuda:{accelerator.local_process_index}'
else:
elif accelerator.num_processes == 1 and device_map == 'auto':
self._device = torch.device(device)
self.device_map = device_map
if isinstance(dtype, str) and dtype != 'auto':
dtype = getattr(torch, dtype)
else:
self._device = torch.device(f'cuda:{accelerator.local_process_index}')
self.device_map = f'cuda:{accelerator.local_process_index}'

llava_model_args = {
'multimodal': True,
}
if customized_config is not None:
llava_model_args['customized_config'] = customized_config
if attn_implementation is not None:
llava_model_args['attn_implementation'] = attn_implementation
if 'use_flash_attention_2' in kwargs:
llava_model_args['use_flash_attention_2'] = kwargs['use_flash_attention_2']
model_name = model_name if model_name is not None else get_model_name_from_path(pretrained)

self._model = llmc_model.cuda()
self.pretrained = pretrained
self._image_processor = AutoProcessor.from_pretrained(
pretrained, revision=revision, trust_remote_code=trust_remote_code
)
# Pad from left for batched generation:
# https://huggingface.co/docs/transformers/v4.39.3/en/model_doc/llava#usage-tips
self._image_processor.tokenizer.padding_side = 'left'
self._tokenizer = self._image_processor.tokenizer
self._config = self._model.config
self._tokenizer = AutoTokenizer.from_pretrained(pretrained, use_fast=False)
self._image_processor = None
if 'llava' in model_name.lower():
mm_use_im_start_end = getattr(self._config, 'mm_use_im_start_end', False)
mm_use_im_patch_token = getattr(self._config, 'mm_use_im_patch_token', True)
if mm_use_im_patch_token:
self._tokenizer.add_tokens([DEFAULT_IMAGE_PATCH_TOKEN], special_tokens=True)
if mm_use_im_start_end:
self._tokenizer.add_tokens(
[DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN],
special_tokens=True
)
self._image_processor = self._model.get_vision_tower().image_processor
if hasattr(self._config, 'max_sequence_length'):
self._max_length = self._config.max_sequence_length
else:
self._max_length = 2048

self.model.eval()
if tie_weights:
self.model.tie_weights()

self.truncation = truncation
self.batch_size_per_gpu = int(batch_size)
self.chat_template = chat_template
self.conv_template = conv_template
self.use_cache = use_cache
if accelerator.num_processes > 1 and device_map == '':
self.truncate_context = truncate_context
# assert self.batch_size_per_gpu == 1, (
# "Llava currently does not support batched generation. "
# "See: https://github.com/haotian-liu/LLaVA/issues/754. "
# "HF Llava also has this issue."
# )
if accelerator.num_processes > 1:
assert accelerator.distributed_type in [
DistributedType.FSDP,
DistributedType.MULTI_GPU,
DistributedType.DEEPSPEED], (
'Unsupported distributed type provided. '
'Only DDP and FSDP are supported.')
# To use DistributedType.DEEPSPEED, run `accelerate config` first.
# You must select zero stage 0 (equivalent to DDP) for model preparation to work.
# Attempts to support zero stage 2 via kwargs failed.
if accelerator.distributed_type == DistributedType.DEEPSPEED:
kwargs = {
'train_micro_batch_size_per_gpu': self.batch_size_per_gpu,
'train_batch_size': self.batch_size_per_gpu
* accelerator.num_processes,
'train_batch_size': self.batch_size_per_gpu * accelerator.num_processes,
}
AcceleratorState().deepspeed_plugin.deepspeed_config_process(
must_match=True, **kwargs
)
logger.info(
'Detected that you are using DistributedType.DEEPSPEED. \
Make sure you run `accelerate config` and set zero stage to 0'
'Detected that you are using DistributedType.DEEPSPEED. '
'Make sure you run `accelerate config` and set zero stage to 0'
)

if (
accelerator.distributed_type == DistributedType.FSDP
or accelerator.distributed_type == DistributedType.DEEPSPEED
):
self._model = accelerator.prepare(self.model)
else:
self._model = accelerator.prepare_model(
self.model, evaluation_mode=True
)
self._model = accelerator.prepare_model(self.model, evaluation_mode=True)
self.accelerator = accelerator
if self.accelerator.is_local_main_process:
logger.info(
f'Using {accelerator.num_processes} devices with data parallelism'
)
logger.info(f'Using {accelerator.num_processes} devices with data parallelism')
self._rank = self.accelerator.local_process_index
self._world_size = self.accelerator.num_processes
elif accelerator.num_processes == 1 and device_map == 'auto':
logger.info(
f'Using {accelerator.num_processes} devices with pipeline parallelism'
)
logger.info(f'Using {accelerator.num_processes} devices with tensor parallelism')
self._rank = 0
self._word_size = 1

Review comment (severity: medium):

There seems to be a slight inconsistency in variable naming here. In other parts of this conditional logic (e.g., new lines 295 and 304), self._world_size is used, which is the standard term. For consistency, and to avoid a potential AttributeError if other code expects _world_size, should this be self._world_size = 1 as well?

Suggested change: replace self._word_size = 1 with self._world_size = 1.
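To illustrate the reviewer's concern, here is a minimal, hypothetical sketch, assuming (as is typical for lmms_eval-style evaluators) that the harness shards requests across processes using self._rank and self._world_size; the helper below is not part of this PR:

```python
# Hypothetical sketch of how an evaluator shards work across processes.
# If the attribute were misspelled as `_word_size`, any later read of
# `self._world_size` would raise AttributeError (or pick up a stale value).
def shard_requests(requests, rank, world_size):
    # Each process takes every `world_size`-th request, starting at `rank`.
    return requests[rank::world_size]

requests = list(range(10))
assert shard_requests(requests, rank=0, world_size=2) == [0, 2, 4, 6, 8]
assert shard_requests(requests, rank=1, world_size=2) == [1, 3, 5, 7, 9]
```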

else:
logger.info(f'Using single device: {self._device}')
self.model.to(self._device)
self._rank = 0
self._word_size = 1
self.accelerator = accelerator
self._world_size = 1
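As a closing note for readers unfamiliar with the wrapping done in build_model above: the PR captures the original forward/generate and swaps in thin wrappers that drop or rename arguments the underlying LLaVA implementation does not accept. Below is a minimal, self-contained sketch of that pattern on a dummy object; DummyLM and its arguments are hypothetical, for illustration only, and are not code from this PR.

```python
class DummyLM:
    """Stand-in for the wrapped model; hypothetical, for illustration only."""

    def forward(self, input_ids, use_cache=True):
        return input_ids

    def generate(self, inputs=None, max_new_tokens=8):
        return inputs


model = DummyLM()

# Wrap forward so unsupported kwargs (e.g. cache_position) are dropped,
# mirroring the safe_forward wrapper in the diff above.
ori_forward = model.forward

def safe_forward(*args, **kwargs):
    kwargs['use_cache'] = False
    kwargs.pop('cache_position', None)
    return ori_forward(*args, **kwargs)

model.forward = safe_forward

# Wrap generate so callers may pass input_ids even though the underlying
# implementation expects inputs, mirroring safe_generate in the diff above.
ori_generate = model.generate

def safe_generate(*args, **kwargs):
    if 'input_ids' in kwargs:
        kwargs['inputs'] = kwargs.pop('input_ids')
    return ori_generate(*args, **kwargs)

model.generate = safe_generate

print(model.forward(input_ids=[1, 2, 3], cache_position=0))  # cache_position is ignored
print(model.generate(input_ids=[1, 2, 3]))                   # rerouted to inputs
```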