
Commit 2fabbb1

Add bucket and padding for Ovis
This PR depends on #1873.
1 parent bdd6f7c commit 2fabbb1

2 files changed (+98 -10 lines)

vllm/model_executor/models/ovis2_5.py (97 additions, 9 deletions)
@@ -10,6 +10,7 @@
 from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig

 from vllm.config import VllmConfig
+from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
@@ -29,9 +30,12 @@
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor
+from vllm.worker.hpu_model_runner import VisionBuckets

 from .interfaces import MultiModalEmbeddings, SupportsMultiModal

+logger = init_logger(__name__)
+
 IMAGE_TOKEN = "<image>"
 VIDEO_TOKEN = "<video>"
 INDICATOR_IDS = [-301, -302, -303, -304]
@@ -416,6 +420,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

         text_model_type = self.config.get_text_config().model_type
         self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type]
+        self.vision_buckets = VisionBuckets(is_batch_based=False)

         # TODO(Isotr0py): PP support
         # self.make_empty_intermediate_tensors = (
@@ -458,30 +463,113 @@ def _parse_and_validate_visual_input(

         raise AssertionError("This line should be unreachable.")

+    def find_factor(self, desired_patches, orig):
+        for i in range(orig + 1, desired_patches + 1):
+            if desired_patches % i == 0:
+                if i % 2 != 0:
+                    continue
+                else:
+                    return i
+        return None
+
+    def find_padding(self, h_orig, w_orig, desired_patches):
+        best_pad_h, best_pad_w = 0, 0
+        if desired_patches % h_orig == 0:
+            best_pad_h = 0
+            w_factor = desired_patches // h_orig
+            best_pad_w = w_factor - w_orig if (w_factor > w_orig
+                                               and w_factor % 2 == 0) else 0
+        elif desired_patches % w_orig == 0:
+            best_pad_w = 0
+            h_factor = desired_patches // w_orig
+            best_pad_h = h_factor - h_orig if (h_factor > h_orig
+                                               and h_factor % 2 == 0) else 0
+        elif desired_patches % h_orig != 0 and desired_patches % w_orig != 0:
+            if h_orig > w_orig:
+                w_factor = self.find_factor(desired_patches, w_orig)
+                if w_factor is not None:
+                    best_pad_w = w_factor - w_orig
+                    h_factor = desired_patches // w_factor
+                    if h_factor > h_orig:
+                        best_pad_h = h_factor - h_orig
+            else:
+                h_factor = self.find_factor(desired_patches, h_orig)
+                if h_factor is not None:
+                    best_pad_h = h_factor - h_orig
+                    w_factor = desired_patches // h_factor
+                    if w_factor > w_orig:
+                        best_pad_w = w_factor - w_orig
+
+        if (best_pad_h + h_orig) * (best_pad_w + w_orig) != desired_patches:
+            best_pad_h, best_pad_w = 0, 0
+
+        return best_pad_h, best_pad_w
+
+    def pad_multimodal_data(self, pixel_values, image_grid_thw,
+                            vision_buckets):
+        desired_number_of_pixels = vision_buckets.get_multimodal_bucket(
+            pixel_values.shape[0])
+        padding_len = desired_number_of_pixels - pixel_values.shape[0]
+        if padding_len <= 0:
+            return pixel_values, image_grid_thw
+
+        logger_msg = "Padding current number pixel " \
+            + str(pixel_values.shape[0]) \
+            + " to " \
+            + str(desired_number_of_pixels)
+        logger.info(logger_msg)
+
+        h_orig, w_orig = image_grid_thw[0, 1].item(), image_grid_thw[0,
+                                                                     2].item()
+        pad_h, pad_w = self.find_padding(h_orig, w_orig,
+                                         desired_number_of_pixels)
+        if pad_h == 0 and pad_w == 0:
+            return pixel_values, image_grid_thw
+
+        constant_value = -100
+        pixel_values = torch.cat([
+            pixel_values,
+            torch.ones((padding_len, pixel_values.shape[1]),
+                       device=pixel_values.device) * constant_value
+        ])
+
+        image_grid_thw = torch.tensor([[1, h_orig + pad_h, w_orig + pad_w]],
+                                      device=image_grid_thw.device,
+                                      dtype=image_grid_thw.dtype)
+
+        assert image_grid_thw.prod(-1).sum() == desired_number_of_pixels
+        return pixel_values, image_grid_thw
+
     def _process_image_input(
             self, image_input: OvisImagePatchInputs) -> MultiModalEmbeddings:
         image_patches_flat = image_input["flat_data"]
-        patches_per_image = image_input["patches_per_image"]
         indicator_tokens = image_input["indicator_tokens"]
         grid_thws = image_input["grids"]

-        indicator_per_image = list(
-            map(lambda x: 2 if x > 1 else x + 2, patches_per_image))
-
         target_dtype = self.visual_tokenizer.dtype
-        visual_tokens = self.visual_tokenizer(
-            image_patches_flat.to(target_dtype), grid_thws)

+        visual_embeds, grid_thws = self.pad_multimodal_data(
+            image_patches_flat.to(target_dtype), grid_thws,
+            self.vision_buckets)
+
+        visual_tokens = self.visual_tokenizer(visual_embeds, grid_thws)
         visual_embeds = self.vte(visual_tokens)  # 1:1 numeric eq.
         indicator_embeds = self.vte(indicator_tokens)
+        padded_patches_per_image = [
+            grid[1] * grid[2] // (self.config.vit_config.hidden_stride**2)
+            for grid in grid_thws
+        ]

-        visual_embeds_per_image = visual_embeds.split(patches_per_image, dim=0)
+        visual_embeds_per_image = visual_embeds.split(padded_patches_per_image,
+                                                      dim=0)
+        indicator_per_image = list(
+            map(lambda x: 2 if x > 1 else x + 2, padded_patches_per_image))
         indicator_embeds_per_image = indicator_embeds.split(
             indicator_per_image)

         vision_embeddings = []
-        for indicator, visual in zip(indicator_embeds_per_image,
-                                     visual_embeds_per_image):
+        for idx, (indicator, visual) in enumerate(
+                zip(indicator_embeds_per_image, visual_embeds_per_image)):
             vision_embeddings_per_image = []
             visual = visual.unsqueeze(0)
             for i in range(visual.shape[0]):
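As a worked example of what find_padding computes (illustration only, not part of the diff; the 30x40 grid and the 1600 bucket are assumed numbers): a 30x40 grid has 1200 patches, and the next bucket up is 1600. Since 1600 % 40 == 0, the second branch runs: h_factor = 1600 // 40 = 40, which is even and larger than 30, so the result is pad_h = 10, pad_w = 0 and the padded grid is 40x40 = 1600 patches, exactly the bucket size.

h_orig, w_orig, bucket = 30, 40, 1600
pad_h, pad_w = 10, 0  # what find_padding(30, 40, 1600) returns per the branch above
assert (h_orig + pad_h) * (w_orig + pad_w) == bucket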
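And a minimal sketch of the padding flow end to end (the tensor shapes, the hidden size of 256, and the _StubBuckets class are hypothetical stand-ins; the real bucket provider is VisionBuckets from vllm/worker/hpu_model_runner.py):

import torch

class _StubBuckets:  # hypothetical stand-in for VisionBuckets
    def get_multimodal_bucket(self, n):
        return 1600  # assume 1600 is the smallest bucket that fits n

pixel_values = torch.randn(1200, 256)  # 30x40 grid of flattened patches
image_grid_thw = torch.tensor([[1, 30, 40]])
# model.pad_multimodal_data(pixel_values, image_grid_thw, _StubBuckets()) would
# append 400 rows filled with -100, giving a 1600 x 256 tensor, and rewrite
# image_grid_thw to [[1, 40, 40]] so that 1 * 40 * 40 matches the bucket.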

vllm/worker/hpu_model_runner.py (1 addition, 1 deletion)

@@ -123,7 +123,7 @@ def __init__(self, is_batch_based):
                 multimodal_buckets = [1, 2, 4, 8]  # batch sizes for gemma3
             else:
                 multimodal_buckets = [
-                    1600, 3136, 4096, 6400, 7744, 9216, 12544
+                    784, 1600, 3136, 4096, 6400, 7744, 9216, 12544
                 ]
         else:
            multimodal_buckets = [int(i) for i in envvar.split(',')]
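For reference, the bucket list above is consumed through get_multimodal_bucket, which is not shown in this diff; the sketch below assumes it returns the smallest bucket that fits the requested patch count, which is what makes the new 784 entry matter: small images such as a 28x28 grid no longer have to be padded all the way up to 1600.

def pick_bucket(num_patches,
                buckets=(784, 1600, 3136, 4096, 6400, 7744, 9216, 12544)):
    # Assumed behaviour of VisionBuckets.get_multimodal_bucket: smallest
    # bucket >= num_patches, falling back to the largest bucket.
    for b in buckets:
        if b >= num_patches:
            return b
    return buckets[-1]

assert pick_bucket(784) == 784    # previously the smallest bucket was 1600
assert pick_bucket(1200) == 1600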
