
Commit 22128e5

Add mediapipe changes more
1 parent 92e8db3 commit 22128e5

File tree

1 file changed: +122 additions, -68 deletions


vllm/model_executor/models/internvl.py

Lines changed: 122 additions & 68 deletions
@@ -15,6 +15,7 @@
 import numpy.typing as npt
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 import torchvision.transforms as T
 from PIL import Image
 from transformers import BatchEncoding, PretrainedConfig, TensorType
@@ -31,6 +32,8 @@
 from queue import Queue
 import io
 import time
+import atexit
+from dataclasses import dataclass

 from vllm.config import VllmConfig
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -353,9 +356,12 @@ def __next__(self):
         img_list = shared_q.get()
         for i in range(len(img_list)):
             # NOTE: this padding is needed because of HW alignment requirement
-            img_list[i] = np.pad(img_list[i],
-                                 (0, 64 - len(img_list[i]) % 64),
-                                 'constant')
+            rem = len(img_list[i]) % 64
+            pad = (64 - rem) % 64
+            if pad:
+                img_list[i] = np.pad(img_list[i],
+                                     (0, pad),
+                                     'constant')
         return img_list

     def get_media_output_type(self):
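
The modulo guard above avoids growing buffers that are already 64-byte aligned; the old expression always appended a full extra block in that case. A standalone sketch of the same logic (pad_to_64 is a hypothetical helper used only for illustration):

import numpy as np

def pad_to_64(buf: np.ndarray) -> np.ndarray:
    # Pad a 1-D byte buffer to the next multiple of 64; no-op when already aligned.
    pad = (64 - len(buf) % 64) % 64
    return np.pad(buf, (0, pad), 'constant') if pad else buf

assert len(pad_to_64(np.zeros(100, dtype=np.uint8))) == 128
assert len(pad_to_64(np.zeros(128, dtype=np.uint8))) == 128  # the old code would have grown this to 192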
@@ -395,15 +401,16 @@ def __init__(self, device, queue_depth, batch_size,
             device="hpu", output_format=it.RGB_I, resize=[img_width, img_height])

         self.mean_node = fn.MediaConst(
-            data=np.array([127.5, 127.5, 127.5], dtype=dt.FLOAT32),
-            shape=[1, 1, 3],
-            dtype=dt.FLOAT32
-        )
+            data=np.array([IMAGENET_MEAN[0]*255.0,
+                           IMAGENET_MEAN[1]*255.0,
+                           IMAGENET_MEAN[2]*255.0], dtype=dt.FLOAT32),
+            shape=[1, 1, 3], dtype=dt.FLOAT32)
+
         self.std_node = fn.MediaConst(
-            data=np.array([1/127.5, 1/127.5, 1/127.5], dtype=dt.FLOAT32),
-            shape=[1, 1, 3],
-            dtype=dt.FLOAT32
-        )
+            data=np.array([1.0/(IMAGENET_STD[0]*255.0),
+                           1.0/(IMAGENET_STD[1]*255.0),
+                           1.0/(IMAGENET_STD[2]*255.0)], dtype=dt.FLOAT32),
+            shape=[1, 1, 3], dtype=dt.FLOAT32)

         self.cmn = fn.CropMirrorNorm(crop_w=img_width, crop_h=img_height, dtype=dt.FLOAT32, device="hpu")

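The old constants (127.5, 1/127.5) normalize pixels to roughly [-1, 1]; the new ones plug in the module's ImageNet mean/std scaled to 0-255 pixel values. A quick NumPy check of the arithmetic, assuming the normalization node computes (x - mean) * inv_std and using the usual ImageNet statistics for illustration:

import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])  # assumed values of the module-level constants
IMAGENET_STD = np.array([0.229, 0.224, 0.225])

x = np.random.randint(0, 256, size=(4, 4, 3)).astype(np.float64)
reference = (x / 255.0 - IMAGENET_MEAN) / IMAGENET_STD                      # standard normalization
via_consts = (x - IMAGENET_MEAN * 255.0) * (1.0 / (IMAGENET_STD * 255.0))   # what the new MediaConst values encode
assert np.allclose(reference, via_consts)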
@@ -424,6 +431,61 @@ def definegraph(self):
         # Return the full processed image - we'll do tiling in Python
         return images

+
+# -----------------------------------------------------------------------------
+# MediaPipe manager (persist pipes/iterators)
+# -----------------------------------------------------------------------------
+@dataclass
+class _PipeState:
+    pipe: hpuMediaPipe | None = None
+    it: MediaGenericPytorchIterator | None = None
+    bsz: int | None = None
+    H: int | None = None
+    W: int | None = None
+
+
+class MediaPipeTiler:
+    """Owns and reuses MediaPipe pipes/iterators for main path"""
+    def __init__(self) -> None:
+        self._main = _PipeState()
+
+    def _rebuild(self, st: _PipeState, *, bsz: int, H: int, W: int) -> None:
+        if st.pipe is not None:
+            try:
+                st.pipe.close()
+            except Exception:
+                pass
+        pipe = hpuMediaPipe("legacy", 0, bsz, 1, "cpu", H, W)
+        pipe.build()
+        st.pipe, st.it, st.bsz, st.H, st.W = pipe, iter(MediaPytorchIterator(pipe)), bsz, H, W
+
+    def ensure_main(self, *, bsz: int, H: int, W: int) -> tuple[hpuMediaPipe, MediaGenericPytorchIterator]:
+        st = self._main
+        if st.pipe is None or st.bsz != bsz or st.H != H or st.W != W:
+            self._rebuild(st, bsz=bsz, H=H, W=W)
+        return st.pipe, st.it  # type: ignore[return-value]
+
+    def reset_iter(self) -> None:
+        st = self._main
+        if st.pipe is not None:
+            st.it = iter(MediaPytorchIterator(st.pipe))
+
+    def close_all(self) -> None:
+        st = self._main
+        try:
+            if st.pipe is not None:
+                st.pipe.close()
+        except Exception:
+            pass
+        finally:
+            st.pipe = None
+            st.it = None
+            st.bsz = st.H = st.W = None
+
+
+_MP = MediaPipeTiler()
+atexit.register(_MP.close_all)
+
 def get_image_info(data):
     # Get image info using PIL without decoding
     try:
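
A usage sketch of the manager's reuse semantics. This is illustrative only: it needs the Habana media stack for hpuMediaPipe to actually run, and the sizes are made up:

pipe, it = _MP.ensure_main(bsz=8, H=896, W=896)    # first call builds the pipe
pipe2, it2 = _MP.ensure_main(bsz=8, H=896, W=896)  # same (bsz, H, W) -> same pipe, no rebuild
assert pipe is pipe2
pipe3, _ = _MP.ensure_main(bsz=8, H=448, W=448)    # shape changed -> old pipe closed, new one built
_MP.close_all()                                    # also registered via atexit for interpreter shutdown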
@@ -466,31 +528,9 @@ def preprocess_images(
         img_sizes.append(img_size)
         batch_sizes.append(batch_size)

-    thumbs = None
-    if use_thumbnail and len(images) > 0:
-        batch_size = len(images)
-        pipe = hpuMediaPipe("legacy", queue_depth, batch_size,
-                            num_threads, "cpu",
-                            patch_size, patch_size)
-        pipe.build()
-        data_loader = MediaPytorchIterator(pipe)
-        data_loader = iter(data_loader)
-
-        img_list = np.empty(shape=[batch_size, ], dtype=object)
-        for i in range(batch_size):
-            img_list[i] = np.frombuffer(images[i], np.uint8)
-
-        shared_q.put(img_list)
-        thumbs = next(data_loader)[0]
-
-        shared_q.task_done()
-        pipe.close()
-        del pipe
-
     image_num_patches = torch.zeros(len(images), dtype=torch.int64)

-    patches = []
-    thumb_idx = 0
+    batch_patches = []
     image_num_patches_idx = 0
     for batch_size, img_size in zip(batch_sizes, img_sizes):
         # calculate the number of blocks without thumbnail
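
The dedicated patch_size x patch_size thumbnail pipe removed here is replaced by deriving thumbnails from the main pipe's output with F.interpolate (see the next hunk). A shape-level sketch of that replacement on CPU with random stand-in data:

import torch
import torch.nn.functional as F

processed = torch.rand(2, 3, 896, 896)   # stand-in for the main pipe's [N, C, H, W] output
thumbs = F.interpolate(processed, size=(448, 448), mode="bilinear", align_corners=False)
assert thumbs.shape == (2, 3, 448, 448)  # one patch_size x patch_size thumbnail per image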
@@ -502,48 +542,62 @@ def preprocess_images(
             use_thumbnail=False,
         )

-        num_patches = blocks + 1 if use_thumbnail and thumbs is not None and blocks > 1 else blocks
-        image_num_patches[image_num_patches_idx:image_num_patches_idx+batch_size] = num_patches
-
-        pipe = hpuMediaPipe("legacy", queue_depth, batch_size,
-                            num_threads, "cpu",
-                            target_height, target_width)
-        pipe.build()
-        data_loader = MediaPytorchIterator(pipe)
-        data_loader = iter(data_loader)
+        # if batch size, H, or W changed, create a new pipe
+        main_pipe, main_iter = _MP.ensure_main(bsz=batch_size, H=target_height, W=target_width)

         img_list = np.empty(shape=[batch_size, ], dtype=object)
         for i in range(batch_size):
             img_list[i] = np.frombuffer(images[i], np.uint8)

         shared_q.put(img_list)
-        processed_images = next(data_loader)[0]
-
-        shared_q.task_done()
-        pipe.close()
-        del pipe
-
-        # Extract tiles
-        tiles = []
-        H, W = target_height, target_width
-        for h_idx in range(H // patch_size):
-            for w_idx in range(W // patch_size):
-                h_start = h_idx * patch_size
-                h_end = h_start + patch_size
-                w_start = w_idx * patch_size
-                w_end = w_start + patch_size
-
-                tile = processed_images[:, :, h_start:h_end, w_start:w_end]
-                tiles.append(tile)
+        try:
+            processed_images = next(main_iter)[0]
+        except StopIteration:
+            _MP.reset_iter()
+            _, main_iter = _MP.ensure_main(bsz=batch_size, H=target_height, W=target_width)
+            processed_images = next(main_iter)[0]
+        finally:
+            shared_q.task_done()
+
+        # vectorized tiling: [N,C,H,W] -> [N,Ty,Tx,C,ps,ps] -> [N,T,C,ps,ps]
+        N, C, H, W = processed_images.shape
+        Ty, Tx = H // patch_size, W // patch_size
+        T = Ty * Tx
+
+        use_thumb_now = use_thumbnail and (T > 1)
+
+        x = processed_images.view(N, C, Ty, patch_size, Tx, patch_size) \
+                            .permute(0, 2, 4, 1, 3, 5) \
+                            .contiguous() \
+                            .view(N, T, C, patch_size, patch_size)  # [N,T,C,ps,ps]
+
+        if use_thumb_now:
+            # [N,3,ps,ps]
+            thumbs_batch = F.interpolate(processed_images, size=(patch_size, patch_size),
+                                         mode="bilinear", align_corners=False)
+
+        num_patches = T + 1 if use_thumb_now else T
+        image_num_patches[image_num_patches_idx:image_num_patches_idx+batch_size] = num_patches
+        image_num_patches_idx += batch_size
+
+        # assemble one batched tensor (tiles + optional thumbnail)
+        if use_thumb_now:
+            out = torch.empty((N, T + 1, C, patch_size, patch_size),
+                              dtype=processed_images.dtype, device=processed_images.device)
+            out[:, :T] = x
+            out[:, T] = thumbs_batch
+            out = out.view(N * (T + 1), C, patch_size, patch_size)  # [N*(T+1),C,ps,ps]
+        else:
+            # no thumbnail (T == 1 or thumbnails disabled)
+            out = x.view(N * T, C, patch_size, patch_size)  # [N*T,C,ps,ps]

-        for i in range(batch_size):
-            for t in tiles:
-                patches.append(t[i])
-            if use_thumbnail and thumbs is not None and len(tiles) > 1:
-                patches.append(thumbs[thumb_idx])
-                thumb_idx += 1
+        batch_patches.append(out)

-    patches_flat = torch.stack(patches, dim=0)
+    patches_flat = (
+        torch.cat(batch_patches, dim=0)
+        if batch_patches
+        else torch.empty((0, 3, patch_size, patch_size), dtype=torch.float32)
+    )
     return patches_flat, image_num_patches

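The view/permute/contiguous chain above is intended to reproduce exactly what the deleted h_idx/w_idx slicing loop produced, just without Python-level iteration. A small CPU check of that equivalence on random data (sizes are illustrative):

import torch

N, C, H, W, ps = 2, 3, 896, 896, 448
img = torch.rand(N, C, H, W)
Ty, Tx = H // ps, W // ps
T = Ty * Tx

# vectorized path, as in the hunk above
vectorized = img.view(N, C, Ty, ps, Tx, ps).permute(0, 2, 4, 1, 3, 5) \
                .contiguous().view(N, T, C, ps, ps)

# old-style explicit slicing, tiles taken in the same row-major (h_idx, w_idx) order
looped = torch.stack([img[:, :, hy*ps:(hy+1)*ps, wx*ps:(wx+1)*ps]
                      for hy in range(Ty) for wx in range(Tx)], dim=1)

assert torch.equal(vectorized, looped)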