1515import numpy .typing as npt
1616import torch
1717import torch .nn as nn
18+ import torch .nn .functional as F
1819import torchvision .transforms as T
1920from PIL import Image
2021from transformers import BatchEncoding , PretrainedConfig , TensorType
3132from queue import Queue
3233import io
3334import time
35+ import atexit
36+ from dataclasses import dataclass
3437
3538from vllm .config import VllmConfig
3639from vllm .model_executor .layers .quantization import QuantizationConfig
@@ -353,9 +356,12 @@ def __next__(self):
353356 img_list = shared_q .get ()
354357 for i in range (len (img_list )):
355358 # NOTE: this padding is needed because of HW alignmnet requirment
356- img_list [i ] = np .pad (img_list [i ],
357- (0 , 64 - len (img_list [i ]) % 64 ),
358- 'constant' )
359+ rem = len (img_list [i ]) % 64
360+ pad = (64 - rem ) % 64
361+ if pad :
362+ img_list [i ] = np .pad (img_list [i ],
363+ (0 , pad ),
364+ 'constant' )
359365 return img_list
360366
361367 def get_media_output_type (self ):
@@ -395,15 +401,16 @@ def __init__(self, device, queue_depth, batch_size,
395401 device = "hpu" , output_format = it .RGB_I , resize = [img_width , img_height ])
396402
397403 self .mean_node = fn .MediaConst (
398- data = np .array ([127.5 , 127.5 , 127.5 ], dtype = dt .FLOAT32 ),
399- shape = [1 , 1 , 3 ],
400- dtype = dt .FLOAT32
401- )
404+ data = np .array ([IMAGENET_MEAN [0 ]* 255.0 ,
405+ IMAGENET_MEAN [1 ]* 255.0 ,
406+ IMAGENET_MEAN [2 ]* 255.0 ], dtype = dt .FLOAT32 ),
407+ shape = [1 , 1 , 3 ], dtype = dt .FLOAT32 )
408+
402409 self .std_node = fn .MediaConst (
403- data = np .array ([1 / 127.5 , 1 / 127.5 , 1 / 127.5 ], dtype = dt . FLOAT32 ),
404- shape = [ 1 , 1 , 3 ] ,
405- dtype = dt .FLOAT32
406- )
410+ data = np .array ([1.0 / ( IMAGENET_STD [ 0 ] * 255.0 ),
411+ 1.0 / ( IMAGENET_STD [ 1 ] * 255.0 ) ,
412+ 1.0 / ( IMAGENET_STD [ 2 ] * 255.0 )], dtype = dt .FLOAT32 ),
413+ shape = [ 1 , 1 , 3 ], dtype = dt . FLOAT32 )
407414
408415 self .cmn = fn .CropMirrorNorm (crop_w = img_width , crop_h = img_height , dtype = dt .FLOAT32 , device = "hpu" )
409416
@@ -424,6 +431,61 @@ def definegraph(self):
424431 # Return the full processed image - we'll do tiling in Python
425432 return images
426433
434+
435+ # -----------------------------------------------------------------------------
436+ # MediaPipe manager (persist pipes/iterators)
437+ # -----------------------------------------------------------------------------
438+ @dataclass
439+ class _PipeState :
440+ pipe : hpuMediaPipe | None = None
441+ it : MediaGenericPytorchIterator | None = None
442+ bsz : int | None = None
443+ H : int | None = None
444+ W : int | None = None
445+
446+
447+ class MediaPipeTiler :
448+ """Owns and reuses MediaPipe pipes/iterators for main path"""
449+ def __init__ (self ) -> None :
450+ self ._main = _PipeState ()
451+
452+ def _rebuild (self , st : _PipeState , * , bsz : int , H : int , W : int ) -> None :
453+ if st .pipe is not None :
454+ try :
455+ st .pipe .close ()
456+ except Exception :
457+ pass
458+ pipe = hpuMediaPipe ("legacy" , 0 , bsz , 1 , "cpu" , H , W )
459+ pipe .build ()
460+ st .pipe , st .it , st .bsz , st .H , st .W = pipe , iter (MediaPytorchIterator (pipe )), bsz , H , W
461+
462+ def ensure_main (self , * , bsz : int , H : int , W : int ) -> tuple [hpuMediaPipe , MediaGenericPytorchIterator ]:
463+ st = self ._main
464+ if st .pipe is None or st .bsz != bsz or st .H != H or st .W != W :
465+ self ._rebuild (st , bsz = bsz , H = H , W = W )
466+ return st .pipe , st .it # type: ignore[return-value]
467+
468+ def reset_iter (self ) -> None :
469+ st = self ._main
470+ if st .pipe is not None :
471+ st .it = iter (MediaPytorchIterator (st .pipe ))
472+
473+ def close_all (self ) -> None :
474+ st = self ._main
475+ try :
476+ if st .pipe is not None :
477+ st .pipe .close ()
478+ except Exception :
479+ pass
480+ finally :
481+ st .pipe = None
482+ st .it = None
483+ st .bsz = st .H = st .W = None
484+
485+
486+ _MP = MediaPipeTiler ()
487+ atexit .register (_MP .close_all )
488+
427489def get_image_info (data ):
428490 # Get image info using PIL without decoding
429491 try :
@@ -466,31 +528,9 @@ def preprocess_images(
466528 img_sizes .append (img_size )
467529 batch_sizes .append (batch_size )
468530
469- thumbs = None
470- if use_thumbnail and len (images ) > 0 :
471- batch_size = len (images )
472- pipe = hpuMediaPipe ("legacy" , queue_depth , batch_size ,
473- num_threads , "cpu" ,
474- patch_size , patch_size )
475- pipe .build ()
476- data_loader = MediaPytorchIterator (pipe )
477- data_loader = iter (data_loader )
478-
479- img_list = np .empty (shape = [batch_size , ], dtype = object )
480- for i in range (batch_size ):
481- img_list [i ] = np .frombuffer (images [i ], np .uint8 )
482-
483- shared_q .put (img_list )
484- thumbs = next (data_loader )[0 ]
485-
486- shared_q .task_done ()
487- pipe .close ()
488- del pipe
489-
490531 image_num_patches = torch .zeros (len (images ), dtype = torch .int64 )
491532
492- patches = []
493- thumb_idx = 0
533+ batch_patches = []
494534 image_num_patches_idx = 0
495535 for batch_size , img_size in zip (batch_sizes , img_sizes ):
496536 # calculate the number of blocks without thumbnail
@@ -502,48 +542,62 @@ def preprocess_images(
502542 use_thumbnail = False ,
503543 )
504544
505- num_patches = blocks + 1 if use_thumbnail and thumbs is not None and blocks > 1 else blocks
506- image_num_patches [image_num_patches_idx :image_num_patches_idx + batch_size ] = num_patches
507-
508- pipe = hpuMediaPipe ("legacy" , queue_depth , batch_size ,
509- num_threads , "cpu" ,
510- target_height , target_width )
511- pipe .build ()
512- data_loader = MediaPytorchIterator (pipe )
513- data_loader = iter (data_loader )
545+ # if batch, H, W is changed, create new one
546+ main_pipe , main_iter = _MP .ensure_main (bsz = batch_size , H = target_height , W = target_width )
514547
515548 img_list = np .empty (shape = [batch_size , ], dtype = object )
516549 for i in range (batch_size ):
517550 img_list [i ] = np .frombuffer (images [i ], np .uint8 )
518551
519552 shared_q .put (img_list )
520- processed_images = next (data_loader )[0 ]
521-
522- shared_q .task_done ()
523- pipe .close ()
524- del pipe
525-
526- # Extract tiles
527- tiles = []
528- H , W = target_height , target_width
529- for h_idx in range (H // patch_size ):
530- for w_idx in range (W // patch_size ):
531- h_start = h_idx * patch_size
532- h_end = h_start + patch_size
533- w_start = w_idx * patch_size
534- w_end = w_start + patch_size
535-
536- tile = processed_images [:, :, h_start :h_end , w_start :w_end ]
537- tiles .append (tile )
553+ try :
554+ processed_images = next (main_iter )[0 ]
555+ except StopIteration :
556+ _MP .reset_iter ()
557+ _ , main_iter = _MP .ensure_main (bsz = batch_size , H = target_height , W = target_width )
558+ processed_images = next (main_iter )[0 ]
559+ finally :
560+ shared_q .task_done ()
561+
562+ # tiling vectorization: [N,C,H,W] -> [N,Ty,Tx,C,ps,ps] -> [N, T, C, ps, ps]
563+ N , C , H , W = processed_images .shape
564+ Ty , Tx = H // patch_size , W // patch_size
565+ T = Ty * Tx
566+
567+ use_thumb_now = use_thumbnail and (T > 1 )
568+
569+ x = processed_images .view (N , C , Ty , patch_size , Tx , patch_size ) \
570+ .permute (0 , 2 , 4 , 1 , 3 , 5 ) \
571+ .contiguous () \
572+ .view (N , T , C , patch_size , patch_size ) # [N,T,C,ps,ps]
573+
574+ if use_thumb_now :
575+ # [N,3,ps,ps]
576+ thumbs_batch = F .interpolate (processed_images , size = (patch_size , patch_size ),
577+ mode = "bilinear" , align_corners = False )
578+
579+ num_patches = T + 1 if use_thumb_now else T
580+ image_num_patches [image_num_patches_idx :image_num_patches_idx + batch_size ] = num_patches
581+ image_num_patches_idx += batch_size
582+
583+ # consist tensor batch based (tile + thumbnail)
584+ if use_thumb_now :
585+ out = torch .empty ((N , T + 1 , C , patch_size , patch_size ),
586+ dtype = processed_images .dtype , device = processed_images .device )
587+ out [:, :T ] = x
588+ out [:, T ] = thumbs_batch
589+ out = out .view (N * (T + 1 ), C , patch_size , patch_size ) # [N*(T+1),C,ps,ps]
590+ else :
591+ # no thumbnail (T==1 or off)
592+ out = x .view (N * T , C , patch_size , patch_size ) # [N*T,C,ps,ps]
538593
539- for i in range (batch_size ):
540- for t in tiles :
541- patches .append (t [i ])
542- if use_thumbnail and thumbs is not None and len (tiles ) > 1 :
543- patches .append (thumbs [thumb_idx ])
544- thumb_idx += 1
594+ batch_patches .append (out )
545595
546- patches_flat = torch .stack (patches , dim = 0 )
596+ patches_flat = (
597+ torch .cat (batch_patches , dim = 0 )
598+ if batch_patches
599+ else torch .empty ((0 , 3 , patch_size , patch_size ), dtype = torch .float32 )
600+ )
547601 return patches_flat , image_num_patches
548602
549603
0 commit comments