|
10 | 10 | from transformers import BaseImageProcessor, BatchFeature, PretrainedConfig |
11 | 11 |
|
12 | 12 | from vllm.config import VllmConfig |
| 13 | +from vllm.logger import init_logger |
13 | 14 | from vllm.model_executor.layers.linear import ReplicatedLinear |
14 | 15 | from vllm.model_executor.layers.quantization.base_config import ( |
15 | 16 | QuantizationConfig) |
|
29 | 30 | from vllm.multimodal.profiling import BaseDummyInputsBuilder |
30 | 31 | from vllm.sequence import IntermediateTensors |
31 | 32 | from vllm.transformers_utils.processors.ovis2_5 import Ovis2_5Processor |
| 33 | +from vllm.worker.hpu_model_runner import VisionBuckets |
32 | 34 |
|
33 | 35 | from .interfaces import MultiModalEmbeddings, SupportsMultiModal |
34 | 36 |
|
| 37 | +logger = init_logger(__name__) |
| 38 | + |
35 | 39 | IMAGE_TOKEN = "<image>" |
36 | 40 | VIDEO_TOKEN = "<video>" |
37 | 41 | INDICATOR_IDS = [-301, -302, -303, -304] |
@@ -416,6 +420,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): |
416 | 420 |
|
417 | 421 | text_model_type = self.config.get_text_config().model_type |
418 | 422 | self.image_pad_token_id = IMAGE_PAD_TOKEN_ID_MAP[text_model_type] |
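| | + # Vision buckets quantize the flattened image-patch count; inputs are |
| | + # padded up to the nearest bucket size (see pad_multimodal_data below). |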
| 423 | + self.vision_buckets = VisionBuckets(is_batch_based=False) |
419 | 424 |
|
420 | 425 | # TODO(Isotr0py): PP support |
421 | 426 | # self.make_empty_intermediate_tensors = ( |
@@ -458,30 +463,113 @@ def _parse_and_validate_visual_input( |
458 | 463 |
|
459 | 464 | raise AssertionError("This line should be unreachable.") |
460 | 465 |
|
| 466 | + def find_factor(self, desired_patches, orig): |
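| | + """Return the smallest even divisor of desired_patches that is |
| | + greater than orig, or None if no such divisor exists.""" |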
| 467 | + for i in range(orig + 1, desired_patches + 1): |
| 468 | + if desired_patches % i == 0 and i % 2 == 0: |
| 469 | + return i |
| 470 | + return None |
| 474 | + |
| 475 | + def find_padding(self, h_orig, w_orig, desired_patches): |
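| | + """Search for (pad_h, pad_w) such that (h_orig + pad_h) * |
| | + (w_orig + pad_w) == desired_patches; fall back to (0, 0) when no |
| | + suitable factorization is found.""" |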
| 476 | + best_pad_h, best_pad_w = 0, 0 |
| 477 | + if desired_patches % h_orig == 0: |
| 478 | + best_pad_h = 0 |
| 479 | + w_factor = desired_patches // h_orig |
| 480 | + best_pad_w = w_factor - w_orig if (w_factor > w_orig |
| 481 | + and w_factor % 2 == 0) else 0 |
| 482 | + elif desired_patches % w_orig == 0: |
| 483 | + best_pad_w = 0 |
| 484 | + h_factor = desired_patches // w_orig |
| 485 | + best_pad_h = h_factor - h_orig if (h_factor > h_orig |
| 486 | + and h_factor % 2 == 0) else 0 |
| 487 | + elif desired_patches % h_orig != 0 and desired_patches % w_orig != 0: |
| 488 | + if h_orig > w_orig: |
| 489 | + w_factor = self.find_factor(desired_patches, w_orig) |
| 490 | + if w_factor is not None: |
| 491 | + best_pad_w = w_factor - w_orig |
| 492 | + h_factor = desired_patches // w_factor |
| 493 | + if h_factor > h_orig: |
| 494 | + best_pad_h = h_factor - h_orig |
| 495 | + else: |
| 496 | + h_factor = self.find_factor(desired_patches, h_orig) |
| 497 | + if h_factor is not None: |
| 498 | + best_pad_h = h_factor - h_orig |
| 499 | + w_factor = desired_patches // h_factor |
| 500 | + if w_factor > w_orig: |
| 501 | + best_pad_w = w_factor - w_orig |
| 502 | + |
| 503 | + if (best_pad_h + h_orig) * (best_pad_w + w_orig) != desired_patches: |
| 504 | + best_pad_h, best_pad_w = 0, 0 |
| 505 | + |
| 506 | + return best_pad_h, best_pad_w |
| 507 | + |
| 508 | + def pad_multimodal_data(self, pixel_values, image_grid_thw, |
| 509 | + vision_buckets): |
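| | + """Pad pixel_values (and the corresponding grid) up to the next |
| | + vision bucket size so the visual tokenizer runs on bucketed shapes. |
| | + Returns the inputs unchanged if no padding is needed or possible.""" |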
| 510 | + desired_number_of_pixels = vision_buckets.get_multimodal_bucket( |
| 511 | + pixel_values.shape[0]) |
| 512 | + padding_len = desired_number_of_pixels - pixel_values.shape[0] |
| 513 | + if padding_len <= 0: |
| 514 | + return pixel_values, image_grid_thw |
| 515 | + |
| 516 | + logger.info( |
| 517 | + "Padding number of image patches from %d to %d", |
| 518 | + pixel_values.shape[0], |
| 519 | + desired_number_of_pixels, |
| 520 | + ) |
| 521 | + |
| 522 | + h_orig = image_grid_thw[0, 1].item() |
| 523 | + w_orig = image_grid_thw[0, 2].item() |
| 524 | + pad_h, pad_w = self.find_padding(h_orig, w_orig, |
| 525 | + desired_number_of_pixels) |
| 526 | + if pad_h == 0 and pad_w == 0: |
| 527 | + return pixel_values, image_grid_thw |
| 528 | + |
| 529 | + constant_value = -100 # sentinel for padded (non-image) patches |
| 530 | + padding = torch.full((padding_len, pixel_values.shape[1]), |
| 531 | + constant_value, |
| 532 | + dtype=pixel_values.dtype, |
| 533 | + device=pixel_values.device) |
| 534 | + pixel_values = torch.cat([pixel_values, padding]) |
| 535 | + |
| 536 | + image_grid_thw = torch.tensor([[1, h_orig + pad_h, w_orig + pad_w]], |
| 537 | + device=image_grid_thw.device, |
| 538 | + dtype=image_grid_thw.dtype) |
| 539 | + |
| 540 | + assert image_grid_thw.prod(-1).sum() == desired_number_of_pixels |
| 541 | + return pixel_values, image_grid_thw |
| 542 | + |
461 | 543 | def _process_image_input( |
462 | 544 | self, image_input: OvisImagePatchInputs) -> MultiModalEmbeddings: |
463 | 545 | image_patches_flat = image_input["flat_data"] |
464 | | - patches_per_image = image_input["patches_per_image"] |
465 | 546 | indicator_tokens = image_input["indicator_tokens"] |
466 | 547 | grid_thws = image_input["grids"] |
467 | 548 |
|
468 | | - indicator_per_image = list( |
469 | | - map(lambda x: 2 if x > 1 else x + 2, patches_per_image)) |
470 | | - |
471 | 549 | target_dtype = self.visual_tokenizer.dtype |
472 | | - visual_tokens = self.visual_tokenizer( |
473 | | - image_patches_flat.to(target_dtype), grid_thws) |
474 | 550 |
|
| 551 | + pixel_values, grid_thws = self.pad_multimodal_data( |
| 552 | + image_patches_flat.to(target_dtype), grid_thws, |
| 553 | + self.vision_buckets) |
| 554 | + |
| 555 | + visual_tokens = self.visual_tokenizer(pixel_values, grid_thws) |
475 | 556 | visual_embeds = self.vte(visual_tokens) # 1:1 numeric eq. |
476 | 557 | indicator_embeds = self.vte(indicator_tokens) |
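| | + # Split sizes per (padded) image: each grid contributes h * w patches, |
| | + # and hidden_stride**2 patches collapse into one visual token. |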
| 558 | + padded_patches_per_image = [ |
| 559 | + grid[1] * grid[2] // (self.config.vit_config.hidden_stride**2) |
| 560 | + for grid in grid_thws |
| 561 | + ] |
477 | 562 |
|
478 | | - visual_embeds_per_image = visual_embeds.split(patches_per_image, dim=0) |
| 563 | + visual_embeds_per_image = visual_embeds.split(padded_patches_per_image, |
| 564 | + dim=0) |
| 565 | + indicator_per_image = list( |
| 566 | + map(lambda x: 2 if x > 1 else x + 2, padded_patches_per_image)) |
479 | 567 | indicator_embeds_per_image = indicator_embeds.split( |
480 | 568 | indicator_per_image) |
481 | 569 |
|
482 | 570 | vision_embeddings = [] |
483 | | - for indicator, visual in zip(indicator_embeds_per_image, |
484 | | - visual_embeds_per_image): |
| 571 | + for idx, (indicator, visual) in enumerate( |
| 572 | + zip(indicator_embeds_per_image, visual_embeds_per_image)): |
485 | 573 | vision_embeddings_per_image = [] |
486 | 574 | visual = visual.unsqueeze(0) |
487 | 575 | for i in range(visual.shape[0]): |
|