108 changes: 98 additions & 10 deletions vllm/model_executor/models/ovis2_5.py
@@ -10,6 +10,8 @@
from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig

from vllm.config import VllmConfig
from vllm.logger import init_logger

from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
@@ -29,9 +31,11 @@
from vllm.multimodal.profiling import BaseDummyInputsBuilder
from vllm.sequence import IntermediateTensors
from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor

from vllm.worker.hpu_model_runner import VisionBuckets
from .interfaces import MultiModalEmbeddings, SupportsMultiModal

logger = init_logger(__name__)

IMAGE_TOKEN = "<image>"
VIDEO_TOKEN = "<video>"
INDICATOR_IDS = [-301, -302, -303, -304]
@@ -416,6 +420,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

text_model_type = self.config.get_text_config().model_type
self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type]
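# Predefined HPU patch-count buckets; used below to pad visual inputs to a
# fixed set of patch counts (see pad_multimodal_data).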
self.vision_buckets = VisionBuckets(is_batch_based=False)

# TODO(Isotr0py): PP support
# self.make_empty_intermediate_tensors = (
@@ -458,30 +463,113 @@ def _parse_and_validate_visual_input(

raise AssertionError("This line should be unreachable.")

    def find_factor(self, desired_patches, orig):
        # Return the smallest even factor of desired_patches that is strictly
        # greater than orig, or None if no such factor exists.
        for i in range(orig + 1, desired_patches + 1):
            if desired_patches % i == 0 and i % 2 == 0:
                return i
        return None

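    # Find (pad_h, pad_w) such that (h_orig + pad_h) * (w_orig + pad_w) equals
    # desired_patches, preferring to pad only one dimension and to keep the
    # padded dimension's new size even. Example: h_orig=28, w_orig=27,
    # desired_patches=784 yields (0, 1) since 28 * 28 == 784.
    # Returns (0, 0) if no exact fit is found.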
    def find_padding(self, h_orig, w_orig, desired_patches):
        best_pad_h, best_pad_w = 0, 0
        if desired_patches % h_orig == 0:
            best_pad_h = 0
            w_factor = desired_patches // h_orig
            best_pad_w = w_factor - w_orig if (w_factor > w_orig
                                               and w_factor % 2 == 0) else 0
        elif desired_patches % w_orig == 0:
            best_pad_w = 0
            h_factor = desired_patches // w_orig
            best_pad_h = h_factor - h_orig if (h_factor > h_orig
                                               and h_factor % 2 == 0) else 0
        else:
            if h_orig > w_orig:
                w_factor = self.find_factor(desired_patches, w_orig)
                if w_factor is not None:
                    best_pad_w = w_factor - w_orig
                    h_factor = desired_patches // w_factor
                    if h_factor > h_orig:
                        best_pad_h = h_factor - h_orig
            else:
                h_factor = self.find_factor(desired_patches, h_orig)
                if h_factor is not None:
                    best_pad_h = h_factor - h_orig
                    w_factor = desired_patches // h_factor
                    if w_factor > w_orig:
                        best_pad_w = w_factor - w_orig

        # Fall back to no padding if the search did not land exactly on the bucket.
        if (best_pad_h + h_orig) * (best_pad_w + w_orig) != desired_patches:
            best_pad_h, best_pad_w = 0, 0

        return best_pad_h, best_pad_w

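    # Pad the flattened pixel patches up to the next vision bucket and rewrite
    # the grid so that t * h * w matches the padded patch count.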
    def pad_multimodal_data(self, pixel_values, image_grid_thw,
                            vision_buckets):
        desired_number_of_pixels = vision_buckets.get_multimodal_bucket(
            pixel_values.shape[0])
        padding_len = desired_number_of_pixels - pixel_values.shape[0]
        if padding_len <= 0:
            return pixel_values, image_grid_thw

        logger.info("Padding visual input from %d to %d patches",
                    pixel_values.shape[0], desired_number_of_pixels)

        h_orig, w_orig = image_grid_thw[0, 1].item(), image_grid_thw[0,
                                                                     2].item()
        pad_h, pad_w = self.find_padding(h_orig, w_orig,
                                         desired_number_of_pixels)
        if pad_h == 0 and pad_w == 0:
            return pixel_values, image_grid_thw

        # Fill the extra rows with a constant so the flattened patch count
        # matches the bucket size; keep the dtype and device of the real patches.
        constant_value = -100
        pixel_values = torch.cat([
            pixel_values,
            torch.full((padding_len, pixel_values.shape[1]),
                       constant_value,
                       dtype=pixel_values.dtype,
                       device=pixel_values.device)
        ])

        image_grid_thw = torch.tensor([[1, h_orig + pad_h, w_orig + pad_w]],
                                      device=image_grid_thw.device,
                                      dtype=image_grid_thw.dtype)

        assert image_grid_thw.prod(-1).sum() == desired_number_of_pixels
        return pixel_values, image_grid_thw

    def _process_image_input(
            self, image_input: OvisImagePatchInputs) -> MultiModalEmbeddings:
        image_patches_flat = image_input["flat_data"]
        patches_per_image = image_input["patches_per_image"]
        indicator_tokens = image_input["indicator_tokens"]
        grid_thws = image_input["grids"]

        indicator_per_image = list(
            map(lambda x: 2 if x > 1 else x + 2, patches_per_image))

        target_dtype = self.visual_tokenizer.dtype

        visual_embeds, grid_thws = self.pad_multimodal_data(
            image_patches_flat.to(target_dtype), grid_thws,
            self.vision_buckets)

        visual_tokens = self.visual_tokenizer(visual_embeds, grid_thws)
        visual_embeds = self.vte(visual_tokens)  # 1:1 numeric eq.
        indicator_embeds = self.vte(indicator_tokens)
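        # Patches per image after padding: each grid contributes
        # h * w // hidden_stride**2 rows of visual_embeds.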
        padded_patches_per_image = [
            grid[1] * grid[2] // (self.config.vit_config.hidden_stride**2)
            for grid in grid_thws
        ]

        visual_embeds_per_image = visual_embeds.split(padded_patches_per_image,
                                                      dim=0)
        indicator_per_image = list(
            map(lambda x: 2 if x > 1 else x + 2, padded_patches_per_image))
        indicator_embeds_per_image = indicator_embeds.split(
            indicator_per_image)

        vision_embeddings = []
        for idx, (indicator, visual) in enumerate(
                zip(indicator_embeds_per_image, visual_embeds_per_image)):
            vision_embeddings_per_image = []
            visual = visual.unsqueeze(0)
            for i in range(visual.shape[0]):
2 changes: 1 addition & 1 deletion vllm/worker/hpu_model_runner.py
@@ -129,7 +129,7 @@ def __init__(self, model):
else:
self.is_batch_based = False
multimodal_buckets = [
784, 1600, 3136, 4096, 6400, 7744, 9216, 12544
]
else:
multimodal_buckets = [int(i) for i in envvar.split(',')]
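
For reference, a minimal sketch of the bucket lookup the padding relies on, assuming get_multimodal_bucket rounds the flat patch count up to the smallest configured bucket that fits (the actual implementation lives in VisionBuckets and may differ):

def get_multimodal_bucket_sketch(patch_count, buckets):
    # Assumed behavior: round up to the smallest bucket >= patch_count,
    # or return patch_count unchanged if it exceeds every bucket (no padding).
    for bucket in sorted(buckets):
        if bucket >= patch_count:
            return bucket
    return patch_count

buckets = [784, 1600, 3136, 4096, 6400, 7744, 9216, 12544]
assert get_multimodal_bucket_sketch(700, buckets) == 784
assert get_multimodal_bucket_sketch(1601, buckets) == 3136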