Merged
33 changes: 25 additions & 8 deletions llmc/compression/token_reduction/random.py
@@ -3,6 +3,7 @@
from types import MethodType

import torch
from loguru import logger

from llmc.utils.registry_factory import TOKEN_REDUCTION_REGISTRY

@@ -62,22 +63,37 @@ def input_hook(module, input_args, pruning_paras):
@prefill_wrapper
def random_pruning_hook(module, args, kwargs, pruning_paras):

logger.info(' ========random_pruning_hook======== ')

medium

Consider removing this logger.info call as it seems to be for debugging purposes. Leaving it in could clutter the logs.
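If the trace is still wanted during development, an alternative to outright removal (a sketch, not part of the reviewer's suggestion; it reuses the loguru logger already imported at the top of random.py) is to downgrade these calls so they can be filtered out by log level instead of deleted:

logger.debug(' ========random_pruning_hook======== ')
logger.debug(f'before hidden_states : {hidden_states.shape}')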


rate = pruning_paras['rate']
image_token_start_index = pruning_paras['image_token_start_index']
image_token_length = pruning_paras['image_token_length']

hidden_states = args[0]
causal_mask = kwargs['attention_mask']

logger.info(f'before hidden_states : {hidden_states.shape}')

medium

This logger.info call appears to be for debugging. It should be removed to avoid unnecessary logging in production.


device = hidden_states.device
vision_indexes = torch.arange(
image_token_start_index,
image_token_start_index + image_token_length,
device=device,
)
num_keep = round(image_token_length * (1 - rate))
rand_idx = torch.randperm(image_token_length, device=device)[:num_keep]
vision_indexes = vision_indexes[rand_idx]

if self.model.first_turn_question:
logger.info(' -----first_turn_question-----')

medium

This logger.info call seems to be for debugging. It should be removed to keep production logs clean.

vision_indexes = torch.arange(
image_token_start_index,
image_token_start_index + image_token_length,
device=device,
)
num_keep = round(image_token_length * (1 - rate))
rand_idx = torch.randperm(image_token_length, device=device)[:num_keep]
vision_indexes = vision_indexes[rand_idx]

# save vision_indexes to module
module.register_buffer('vision_indexes', vision_indexes)
else:
logger.info(' -----not first_turn_question-----')

medium

This logger.info call appears to be for debugging. It should be removed to avoid unnecessary logging in production.

# load vision_indexes from module (prompt cache)
vision_indexes = module.vision_indexes

# keep index
keep_indexs = torch.cat(
(
@@ -115,6 +131,7 @@ def random_pruning_hook(module, args, kwargs, pruning_paras):
position_embeddings[0].resize_as_(new_pe0).copy_(new_pe0)
position_embeddings[1].resize_as_(new_pe0).copy_(new_pe1)

logger.info(f'after hidden_states : {hidden_states.shape}')

medium

Consider removing this logger.info call as it seems to be for debugging purposes. Leaving it in could clutter the logs.

return (hidden_states,), kwargs

if self.model.__class__.__name__ == 'LlavaHf':
1 change: 1 addition & 0 deletions llmc/eval/__init__.py
@@ -1,6 +1,7 @@
from .eval_acc import AccuracyEval
from .eval_code import HumanEval
from .eval_custom_generate import CustomGenerate
from .eval_custom_generate_just_infer import CustomGenerateJustInfer
from .eval_ppl import DecodePerplexityEval, PerplexityEval
from .eval_token_consist import TokenConsistencyEval
from .eval_video_generate import VideoGenerateEval
36 changes: 36 additions & 0 deletions llmc/eval/eval_custom_generate_just_infer.py
@@ -0,0 +1,36 @@
import glob
import json
import os

import torch
from human_eval.data import stream_jsonl, write_jsonl
from human_eval.evaluation import evaluate_functional_correctness
from loguru import logger
from tqdm import tqdm

from .eval_base import BaseEval


class CustomGenerateJustInfer:

high

The class CustomGenerateJustInfer should inherit from BaseEval to maintain consistency with other evaluation classes and ensure compatibility with the evaluation framework.

Suggested change
class CustomGenerateJustInfer:
class CustomGenerateJustInfer(BaseEval):

def __init__(self, model, config):
self.model = model
self.config = config
self.eval_cfg = config.eval

@torch.no_grad()
def eval(self, model, eval_pos=None):

medium

The model parameter of the eval method is unused; the method already works with the instance attribute self.model. Remove the parameter to avoid confusion.

Suggested change
def eval(self, model, eval_pos=None):
def eval(self, eval_pos=None):

logger.info('start inference')

with open(os.path.join(self.eval_cfg.path, 'samples.json'), 'r') as f:
questions_list = json.load(f)

custom_samples_ans = self.model.eval_custom_samples_just_infer(
questions_list,
self.eval_cfg
)

with open(os.path.join('custom_samples_ans.json'), 'w') as f:

medium

The output filename custom_samples_ans.json is hardcoded. Consider saving the output in the same directory as the input data, using self.eval_cfg.path for better flexibility.

Suggested change
with open(os.path.join('custom_samples_ans.json'), 'w') as f:
with open(os.path.join(self.eval_cfg.path, 'custom_samples_ans.json'), 'w') as f:

json.dump(custom_samples_ans, f, indent=4)

torch.cuda.empty_cache()
return 'custom gen done.'
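For context, the samples.json this evaluator loads is consumed by eval_custom_samples_just_infer in llava.py as a list of entries, each carrying an image list and a question list (an answer list is appended per entry). A minimal sketch with hypothetical file names and questions:

[
    {
        "image": ["example.jpg"],
        "question": ["What is shown in the image?", "Describe it in one sentence."]
    }
]

Image paths are resolved relative to the images/ folder under eval_cfg.path.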
8 changes: 5 additions & 3 deletions llmc/eval/utils.py
@@ -3,9 +3,9 @@

from loguru import logger

from llmc.eval import (AccuracyEval, CustomGenerate, DecodePerplexityEval,
HumanEval, PerplexityEval, TokenConsistencyEval,
VideoGenerateEval, VQAEval)
from llmc.eval import (AccuracyEval, CustomGenerate, CustomGenerateJustInfer,
DecodePerplexityEval, HumanEval, PerplexityEval,
TokenConsistencyEval, VideoGenerateEval, VQAEval)
from llmc.utils import deploy_all_modality


@@ -57,6 +57,8 @@ def get_eval_list(model, config):
eval_class = HumanEval(model, config_for_eval)
elif config_tmp.eval.type == 'generate_only':
eval_class = CustomGenerate(model, config_for_eval)
elif config_tmp.eval.type == 'just_infer':
eval_class = CustomGenerateJustInfer(model, config_for_eval)
elif config_tmp.eval.type == 'token_acc':
eval_class = TokenConsistencyEval(model, config_for_eval)
elif config_tmp.eval.type == 'ppl':
101 changes: 98 additions & 3 deletions llmc/models/llava.py
@@ -1,3 +1,4 @@
import os
import types
from datetime import timedelta
from typing import Optional, Union
@@ -9,6 +10,7 @@
from lmms_eval.models.llava import Llava as LLaVA
from loguru import logger
from packaging import version
from PIL import Image
from transformers import AutoConfig, AutoTokenizer

from llmc.utils.registry_factory import MODEL_REGISTRY
@@ -17,8 +19,11 @@

try:
from llava.constants import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
DEFAULT_IMAGE_PATCH_TOKEN, IMAGE_TOKEN_INDEX)
from llava.mm_utils import get_model_name_from_path
DEFAULT_IMAGE_PATCH_TOKEN,
DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX)
from llava.conversation import SeparatorStyle, conv_templates
from llava.mm_utils import (get_model_name_from_path, process_images,
tokenizer_image_token)
from llava.model.builder import load_pretrained_model
from llava.model.language_model.llava_llama import LlavaConfig
except Exception as e:
@@ -45,7 +50,7 @@ def build_model(self):
self.vlm_model_config.use_cache = True
logger.info(f'self.vlm_model_config : {self.vlm_model_config}')

self.tokenizer, self.vlm_model, image_processor, context_len = load_pretrained_model(
self.tokenizer, self.vlm_model, self.image_processor, context_len = load_pretrained_model(
self.model_path,
None,
get_model_name_from_path(self.model_path),
@@ -137,6 +142,96 @@ def get_subsets_in_block(self, block):
else:
raise Exception(f'Llava do not support {self.get_modality()} modality.')

def eval_custom_samples_just_infer(
self,
img_qas,
eval_cfg
): # noqa

custom_samples_ans = img_qas.copy()

self.vlm_model.cuda()

def load_image(image_file):
image = Image.open(image_file).convert('RGB')
return image

def load_images(image_files):
out = []
for image_file in image_files:
image = load_image(image_file)
out.append(image)
return out

self.first_turn_question = True

for data_idx, questions in enumerate(img_qas):
self.first_turn_question = True

custom_samples_ans[data_idx]['answer'] = []

image_files = questions['image']
image_files = [os.path.join(eval_cfg.path, 'images', image_file) for image_file in image_files] # noqa
images = load_images(image_files)
image_sizes = [x.size for x in images]
images_tensor = process_images(
images,
self.image_processor,
self.vlm_model.config
).to(self.vlm_model.device, dtype=torch.float16)

input_ids_old = None

for question_idx, question in enumerate(questions['question']):

conv_mode = 'llava_v1'
conv = conv_templates[conv_mode].copy()
if question_idx > 0:
conv.system = ''
qs = question
self.first_turn_question = False
else:
qs = DEFAULT_IMAGE_TOKEN + '\n' + question
conv.append_message(conv.roles[0], qs)
conv.append_message(conv.roles[1], None)
prompt = conv.get_prompt()

input_ids = tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors='pt').unsqueeze(0).cuda() # noqa
# print(f"input_ids 1: {input_ids}, {input_ids.shape}")
if input_ids_old is not None:
input_ids = torch.cat((input_ids_old, input_ids), dim=1)
# print(f"input_ids 2: {input_ids}, {input_ids.shape}")

with torch.inference_mode():
output_ids = self.vlm_model.generate(
input_ids,
attention_mask=input_ids.new_ones(input_ids.shape, dtype=torch.bool),
images=images_tensor,
image_sizes=image_sizes,
do_sample=False,
top_p=None,
num_beams=1,
max_new_tokens=eval_cfg.max_new_tokens,
use_cache=True,
)

# print(f"output_ids: {output_ids}, {output_ids.shape}")

outputs = self.tokenizer.batch_decode(output_ids, skip_special_tokens=True)

print('--------------------------------')
print(f'data_idx: {data_idx}')
print(f'question_idx: {question_idx}')
print(f'question: {question}')
print(f'outputs: {outputs}')
print('--------------------------------')
Comment on lines +222 to +227

medium

These print statements appear to be for debugging. To avoid cluttering standard output, remove them or, if the information is still useful, replace them with structured logging (e.g., logger.debug).
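One possible shape for that replacement (a sketch only; it assumes the module-level loguru logger imported at the top of llava.py is the intended logging mechanism):

logger.debug(f'data_idx: {data_idx}, question_idx: {question_idx}')
logger.debug(f'question: {question}')
logger.debug(f'outputs: {outputs}')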


custom_samples_ans[data_idx]['answer'].append(outputs[0])

input_ids_old = torch.cat((input_ids, output_ids), dim=1)

critical

The history input_ids_old is updated by concatenating input_ids with output_ids. Since output_ids already includes input_ids, this leads to duplicated context. Update the history by assigning output_ids to input_ids_old.

Suggested change
input_ids_old = torch.cat((input_ids, output_ids), dim=1)
input_ids_old = output_ids


return custom_samples_ans


if version.parse(torch.__version__) >= version.parse('2.1.2'):
best_fit_attn_implementation = 'sdpa'