 from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig

 from vllm.config import VllmConfig
+from vllm.logger import init_logger
+
 from vllm.model_executor.layers.linear import ReplicatedLinear
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
@@ ... @@
 from vllm.multimodal.profiling import BaseDummyInputsBuilder
 from vllm.sequence import IntermediateTensors
 from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor
-
+from vllm.worker.hpu_model_runner import VisionBuckets
 from .interfaces import MultiModalEmbeddings, SupportsMultiModal

+logger = init_logger(__name__)
+
 IMAGE_TOKEN = "<image>"
 VIDEO_TOKEN = "<video>"
 INDICATOR_IDS = [-301, -302, -303, -304]
@@ -416,6 +420,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):

         text_model_type = self.config.get_text_config().model_type
         self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type]
+        self.vision_buckets = VisionBuckets(is_batch_based=False)

         # TODO(Isotr0py): PP support
         # self.make_empty_intermediate_tensors = (
@@ -458,30 +463,113 @@ def _parse_and_validate_visual_input(

         raise AssertionError("This line should be unreachable.")

+    def find_factor(self, desired_patches, orig):
+        for i in range(orig + 1, desired_patches + 1):
+            if desired_patches % i == 0:
+                if i % 2 != 0:
+                    continue
+                else:
+                    return i
+        return None
+
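As a quick illustration of find_factor (the numbers below are chosen for the example, not taken from the change): it returns the smallest even divisor of desired_patches that is strictly greater than orig, or None if no such divisor exists.

    find_factor(168, 10)  # -> 12: 12 divides 168, is even, and is the first such value above 10
    find_factor(15, 3)    # -> None: the only divisors above 3 (5 and 15) are odd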
+    def find_padding(self, h_orig, w_orig, desired_patches):
+        best_pad_h, best_pad_w = 0, 0
+        if desired_patches % h_orig == 0:
+            best_pad_h = 0
+            w_factor = desired_patches // h_orig
+            best_pad_w = w_factor - w_orig if (w_factor > w_orig
+                                               and w_factor % 2 == 0) else 0
+        elif desired_patches % w_orig == 0:
+            best_pad_w = 0
+            h_factor = desired_patches // w_orig
+            best_pad_h = h_factor - h_orig if (h_factor > h_orig
+                                               and h_factor % 2 == 0) else 0
+        elif desired_patches % h_orig != 0 and desired_patches % w_orig != 0:
+            if h_orig > w_orig:
+                w_factor = self.find_factor(desired_patches, w_orig)
+                if w_factor is not None:
+                    best_pad_w = w_factor - w_orig
+                    h_factor = desired_patches // w_factor
+                    if h_factor > h_orig:
+                        best_pad_h = h_factor - h_orig
+            else:
+                h_factor = self.find_factor(desired_patches, h_orig)
+                if h_factor is not None:
+                    best_pad_h = h_factor - h_orig
+                    w_factor = desired_patches // h_factor
+                    if w_factor > w_orig:
+                        best_pad_w = w_factor - w_orig
+
+        if (best_pad_h + h_orig) * (best_pad_w + w_orig) != desired_patches:
+            best_pad_h, best_pad_w = 0, 0
+
+        return best_pad_h, best_pad_w
+
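A worked example of the padding search, again with made-up grid sizes:

    find_padding(30, 40, 1600)  # 1600 % 40 == 0 and 1600 // 40 = 40 is even and > 30 -> (10, 0)
    find_padding(11, 10, 168)   # neither dim divides 168; find_factor(168, 10) = 12 -> (3, 2), since 14 * 12 == 168

If no even factorization covering desired_patches is found, the final consistency check resets both paddings to zero, and the caller below leaves the input unpadded.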
+    def pad_multimodal_data(self, pixel_values, image_grid_thw,
+                            vision_buckets):
+        desired_number_of_pixels = vision_buckets.get_multimodal_bucket(
+            pixel_values.shape[0])
+        padding_len = desired_number_of_pixels - pixel_values.shape[0]
+        if padding_len <= 0:
+            return pixel_values, image_grid_thw
+
+        logger_msg = "Padding current number pixel " \
+            + str(pixel_values.shape[0]) \
+            + " to " \
+            + str(desired_number_of_pixels)
+        logger.info(logger_msg)
+
+        h_orig, w_orig = image_grid_thw[0, 1].item(), image_grid_thw[0,
+                                                                     2].item()
+        pad_h, pad_w = self.find_padding(h_orig, w_orig,
+                                         desired_number_of_pixels)
+        if pad_h == 0 and pad_w == 0:
+            return pixel_values, image_grid_thw
+
+        constant_value = -100
+        pixel_values = torch.cat([
+            pixel_values,
+            torch.ones((padding_len, pixel_values.shape[1]),
+                       device=pixel_values.device) * constant_value
+        ])
+
+        image_grid_thw = torch.tensor([[1, h_orig + pad_h, w_orig + pad_w]],
+                                      device=image_grid_thw.device,
+                                      dtype=image_grid_thw.dtype)
+
+        assert image_grid_thw.prod(-1).sum() == desired_number_of_pixels
+        return pixel_values, image_grid_thw
+
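A sketch of the intended effect of pad_multimodal_data, assuming a hypothetical bucket size of 1600 patches (the real values come from VisionBuckets.get_multimodal_bucket and are not part of this diff):

    # pixel_values has 1200 rows for one image with image_grid_thw == [[1, 30, 40]]
    # get_multimodal_bucket(1200) -> 1600 (assumed), so padding_len == 400
    # find_padding(30, 40, 1600) -> (10, 0), so the grid becomes [[1, 40, 40]]
    # 400 rows filled with the constant -100 are appended, giving 1600 rows in total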
     def _process_image_input(
             self, image_input: OvisImagePatchInputs) -> MultiModalEmbeddings:
         image_patches_flat = image_input["flat_data"]
-        patches_per_image = image_input["patches_per_image"]
         indicator_tokens = image_input["indicator_tokens"]
         grid_thws = image_input["grids"]

-        indicator_per_image = list(
-            map(lambda x: 2 if x > 1 else x + 2, patches_per_image))
-
         target_dtype = self.visual_tokenizer.dtype
-        visual_tokens = self.visual_tokenizer(
-            image_patches_flat.to(target_dtype), grid_thws)

+        visual_embeds, grid_thws = self.pad_multimodal_data(
+            image_patches_flat.to(target_dtype), grid_thws,
+            self.vision_buckets)
+
+        visual_tokens = self.visual_tokenizer(visual_embeds, grid_thws)
         visual_embeds = self.vte(visual_tokens)  # 1:1 numeric eq.
         indicator_embeds = self.vte(indicator_tokens)
+        padded_patches_per_image = [
+            grid[1] * grid[2] // (self.config.vit_config.hidden_stride**2)
+            for grid in grid_thws
+        ]

-        visual_embeds_per_image = visual_embeds.split(patches_per_image, dim=0)
+        visual_embeds_per_image = visual_embeds.split(padded_patches_per_image,
+                                                      dim=0)
+        indicator_per_image = list(
+            map(lambda x: 2 if x > 1 else x + 2, padded_patches_per_image))
         indicator_embeds_per_image = indicator_embeds.split(
             indicator_per_image)

         vision_embeddings = []
-        for indicator, visual in zip(indicator_embeds_per_image,
-                                     visual_embeds_per_image):
+        for idx, (indicator, visual) in enumerate(
+                zip(indicator_embeds_per_image, visual_embeds_per_image)):
             vision_embeddings_per_image = []
             visual = visual.unsqueeze(0)
             for i in range(visual.shape[0]):
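For context on padded_patches_per_image above: each image contributes h * w // hidden_stride**2 visual embeddings after the visual tokenizer, computed from its (possibly padded) grid. With the padded grid [[1, 40, 40]] from the earlier sketch and a hypothetical hidden_stride of 2, that is 1600 // 4 = 400 embeddings, which is the split size used for that image when slicing visual_embeds and sizing indicator_per_image.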
 | 