from PIL import Image
from transformers import BatchEncoding, PretrainedConfig, TensorType

+import io
+import os  # for the VLLM_USE_MEDIA_PIPELINE check below (if not already imported)
+import time
+from queue import Queue
+
+import numpy as np
+
+from habana_frameworks.mediapipe import fn
+from habana_frameworks.mediapipe.media_types import dtype as dt
+from habana_frameworks.mediapipe.media_types import imgtype as it
+from habana_frameworks.mediapipe.media_types import readerOutType as ro
+from habana_frameworks.mediapipe.mediapipe import MediaPipe
+from habana_frameworks.mediapipe.operators.reader_nodes.reader_nodes import (
+    media_ext_reader_op_impl, media_ext_reader_op_tensor_info)
+from habana_frameworks.mediapipe.plugins.iterator_pytorch import (
+    MediaGenericPytorchIterator)
+
from vllm.config import VllmConfig
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.quantization.awq import AWQConfig
@@ -295,6 +308,244 @@ def video_to_pixel_values_internvl(
    pixel_values = torch.stack([transform(image) for image in frames_list])
    return pixel_values

+# Patch the MediaPipe pipe_manager destructor so it survives interpreter
+# shutdown, when module globals such as cpp_pipe_manager_list may already
+# have been torn down.
+from habana_frameworks.mediapipe.backend.cal import (cpp_pipe_manager_list,
+                                                     pipe_manager)
+
+
+def _patched_close(self):
+    """Close that tolerates a None cpp_pipe_manager_list during shutdown."""
+    try:
+        # Check that cpp_pipe_manager_list still exists and is iterable.
+        if cpp_pipe_manager_list is not None and self._pm_ in cpp_pipe_manager_list:
+            cpp_pipe_manager_list.remove(self._pm_)
+    except (TypeError, AttributeError):
+        # cpp_pipe_manager_list is None or no longer iterable.
+        pass
+
+    # Clean up the underlying pipe manager.
+    if self._pm_ is not None:
+        self._pm_.close()
+        self._pm_ = None
+
+
+pipe_manager.close = _patched_close
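+# NOTE: assigning on the class patches close() for every pipe_manager
+# instance, so any pipe created in this process gets the shutdown-safe
+# behavior.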
+
+# Queue shared between the external reader and the host-side mediapipe call:
+# the host puts one batch of encoded image buffers, and
+# external_reader.__next__ pops it from inside the pipeline.
+shared_q = Queue()
+
+
+class MediaPytorchIterator(MediaGenericPytorchIterator):
+
+    def __init__(self, mediapipe):
+        super().__init__(mediapipe=mediapipe, device="hpu", fw_type="PYT_FW")
+
+
+class external_reader(media_ext_reader_op_impl):
+    """External MediaPipe reader fed through the module-level shared_q.
+
+    The host thread puts one batch of encoded JPEG buffers on shared_q
+    before pulling from the iterator; __next__ blocks until that batch
+    arrives. num_batches is 1, so each pipe instance consumes exactly one
+    batch before the host closes it.
+    """
+
+    def __init__(self, params, fw_params):
+        self.batch_size = fw_params.batch_size
+        self.max_file = ""
+        self.num_batches = 1
+
+    def __iter__(self):
+        return self
+
+    def __len__(self):
+        return self.num_batches
+
+    def __next__(self):
+        img_list = shared_q.get()
+        for i in range(len(img_list)):
+            # NOTE: this padding is needed because of a HW alignment
+            # requirement.
+            img_list[i] = np.pad(img_list[i],
+                                 (0, 64 - len(img_list[i]) % 64),
+                                 'constant')
+        return img_list
+
+    def get_media_output_type(self):
+        return ro.BUFFER_LIST
+
+    def get_largest_file(self):
+        return self.max_file
+
+    def gen_output_info(self):
+        out_info = []
+        o = media_ext_reader_op_tensor_info(
+            dt.NDT, np.array([self.batch_size], dtype=np.uint32), "")
+        out_info.append(o)
+        return out_info
+
+
+class hpuMediaPipe(MediaPipe):
+
+    def __init__(self, device, queue_depth, batch_size, num_threads,
+                 op_device, img_height, img_width):
+        super().__init__(device, queue_depth, batch_size, num_threads,
+                         self.__class__.__name__)
+
+        mediapipe_seed = int(time.time_ns() % (2**31 - 1))
+
+        self.input = fn.MediaExtReaderOp(impl=external_reader,
+                                         num_outputs=1,
+                                         seed=mediapipe_seed,
+                                         device=op_device)
+        self.decode = fn.ImageDecoder(device="hpu",
+                                      output_format=it.RGB_I,
+                                      resize=[img_width, img_height])
+
+        self.mean_node = fn.MediaConst(data=np.array([127.5, 127.5, 127.5],
+                                                     dtype=dt.FLOAT32),
+                                       shape=[1, 1, 3],
+                                       dtype=dt.FLOAT32)
+        self.std_node = fn.MediaConst(data=np.array(
+            [1 / 127.5, 1 / 127.5, 1 / 127.5], dtype=dt.FLOAT32),
+                                      shape=[1, 1, 3],
+                                      dtype=dt.FLOAT32)
+
+        self.cmn = fn.CropMirrorNorm(crop_w=img_width,
+                                     crop_h=img_height,
+                                     dtype=dt.FLOAT32,
+                                     device="hpu")
+
+        self.transpose = fn.Transpose(
+            device="hpu",
+            tensorDim=4,
+            permutation=[1, 2, 0, 3]  # -> NCHW
+        )
+
+    def definegraph(self):
+        images = self.input()
+        images = self.decode(images)
+        mean = self.mean_node()
+        std = self.std_node()
+        images = self.cmn(images, mean, std)
+        images = self.transpose(images)
+
+        # Return the fully processed batch; tiling is done on the host in
+        # preprocess_images.
+        return images
+
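+# A minimal sketch of the single-batch lifecycle used by preprocess_images
+# below (all values illustrative):
+#   pipe = hpuMediaPipe("legacy", 0, n, 1, "cpu", height, width)
+#   pipe.build()
+#   loader = iter(MediaPytorchIterator(pipe))
+#   shared_q.put(batch_of_encoded_jpeg_buffers)  # feeds external_reader
+#   decoded = next(loader)[0]  # processed batch, expected (n, 3, h, w)
+#   shared_q.task_done()
+#   pipe.close()
+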
+def get_image_info(data):
+    # Read image metadata with PIL without decoding the pixel data.
+    try:
+        with Image.open(io.BytesIO(data)) as img:
+            return {
+                'format': img.format,
+                'size': img.size,
+                'mode': img.mode,
+            }
+    except Exception as e:
+        raise ValueError(
+            f"Input image bitstream is not in a supported format: {e}")
+
+def preprocess_images(
+    images,
+    target_ratios: list[tuple[int, int]],
+    patch_size=448,
+    use_thumbnail=False,
+):
+    batch_size = 0
+    queue_depth = 0
+    num_threads = 1
+
+    # Validate the images and group consecutive same-sized images into
+    # batches.
+    img_size = None
+    img_sizes = []
+    batch_sizes = []
+    for img in images:
+        img_info = get_image_info(img)
+        if img_info['format'] != 'JPEG' or img_info['mode'] != 'RGB':
+            raise ValueError(
+                "HPU media pipeline only supports JPEG images in RGB mode. "
+                f"Detected format={img_info['format']}, "
+                f"mode={img_info['mode']}")
+        if img_size is None:
+            img_size = img_info['size']
+        elif img_info['size'] != img_size:
+            img_sizes.append(img_size)
+            batch_sizes.append(batch_size)
+            batch_size = 0
+            img_size = img_info['size']
+        batch_size += 1
+    img_sizes.append(img_size)
+    batch_sizes.append(batch_size)
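+    # For example (illustrative sizes): inputs sized [A, A, B, A] produce
+    # img_sizes [A, B, A] and batch_sizes [2, 1, 1]; only consecutive images
+    # with identical sizes share a pipe batch.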
+
+    thumbs = None
+    if use_thumbnail and len(images) > 0:
+        # One extra pass decodes every image at thumbnail (patch) size.
+        batch_size = len(images)
+        pipe = hpuMediaPipe("legacy", queue_depth, batch_size, num_threads,
+                            "cpu", patch_size, patch_size)
+        pipe.build()
+        data_loader = iter(MediaPytorchIterator(pipe))
+
+        img_list = np.empty(shape=[batch_size, ], dtype=object)
+        for i in range(batch_size):
+            img_list[i] = np.frombuffer(images[i], np.uint8)
+
+        shared_q.put(img_list)
+        thumbs = next(data_loader)[0]
+
+        shared_q.task_done()
+        pipe.close()
+        del pipe
+
+    image_num_patches = torch.zeros(len(images), dtype=torch.int64)
+
+    patches = []
+    # Running offset of the first image of the current batch within `images`.
+    image_num_patches_idx = 0
+    for batch_size, img_size in zip(batch_sizes, img_sizes):
+        # Calculate the number of blocks without the thumbnail.
+        blocks, target_width, target_height = calculate_internvl_targets(
+            orig_width=img_size[0],
+            orig_height=img_size[1],
+            target_ratios=target_ratios,
+            image_size=patch_size,
+            use_thumbnail=False,
+        )
+
+        num_patches = (blocks + 1 if use_thumbnail and thumbs is not None
+                       and blocks > 1 else blocks)
+        image_num_patches[image_num_patches_idx:image_num_patches_idx +
+                          batch_size] = num_patches
+
+        pipe = hpuMediaPipe("legacy", queue_depth, batch_size, num_threads,
+                            "cpu", target_height, target_width)
+        pipe.build()
+        data_loader = iter(MediaPytorchIterator(pipe))
+
+        img_list = np.empty(shape=[batch_size, ], dtype=object)
+        for i in range(batch_size):
+            # Index relative to the batch offset, not from the start of
+            # `images`, so later batches read the right inputs.
+            img_list[i] = np.frombuffer(
+                images[image_num_patches_idx + i], np.uint8)
+
+        shared_q.put(img_list)
+        processed_images = next(data_loader)[0]
+
+        shared_q.task_done()
+        pipe.close()
+        del pipe
+
+        # Extract patch_size x patch_size tiles from the processed batch.
+        tiles = []
+        H, W = target_height, target_width
+        for h_idx in range(H // patch_size):
+            for w_idx in range(W // patch_size):
+                h_start = h_idx * patch_size
+                h_end = h_start + patch_size
+                w_start = w_idx * patch_size
+                w_end = w_start + patch_size
+
+                tile = processed_images[:, :, h_start:h_end, w_start:w_end]
+                tiles.append(tile)
+
+        for i in range(batch_size):
+            for t in tiles:
+                patches.append(t[i])
+            if use_thumbnail and thumbs is not None and len(tiles) > 1:
+                # Thumbnails were produced for all images in input order, so
+                # index them by the image's global position.
+                patches.append(thumbs[image_num_patches_idx + i])
+
+        # Advance to the first image of the next batch.
+        image_num_patches_idx += batch_size
+
+    # Shapes: (total_patches, 3, patch_size, patch_size) and (num_images,).
+    patches_flat = torch.stack(patches, dim=0)
+    return patches_flat, image_num_patches
+
+

class BaseInternVLProcessor(ABC):
    """
@@ -451,25 +702,58 @@ def _preprocess_image(
        if len(images) == 0:
            image_inputs = {}
        else:
-            pixel_values_lst = self._images_to_pixel_values_lst(
-                images,
-                min_dynamic_patch=min_dynamic_patch,
-                max_dynamic_patch=max_dynamic_patch,
-                dynamic_image_size=dynamic_image_size,
-            )
-            image_inputs: dict[str, NestedTensors] = {
-                "pixel_values_flat":
-                torch.cat(pixel_values_lst),
-                "image_num_patches":
-                torch.tensor([len(item) for item in pixel_values_lst]),
-            }
+            use_mediapipe = os.getenv("VLLM_USE_MEDIA_PIPELINE",
+                                      "false").lower() in ("1", "true", "yes")
+            if use_mediapipe:
+                # Use the HPU media pipeline for image preprocessing.
+                min_num, max_num = self.resolve_min_max_num(
+                    min_dynamic_patch=min_dynamic_patch,
+                    max_dynamic_patch=max_dynamic_patch,
+                    dynamic_image_size=dynamic_image_size,
+                    use_thumbnail=False,  # Applied in preprocess_images
+                )

-            for pixel_values in pixel_values_lst:
-                num_patches = pixel_values.shape[0]
-                feature_size = num_patches * self.num_image_token
+                target_ratios = get_internvl_target_ratios(min_num, max_num)
+
+                pixel_values_flat, image_num_patches = preprocess_images(
+                    images,
+                    target_ratios=target_ratios,
+                    patch_size=self.image_size,
+                    use_thumbnail=self.use_thumbnail,
+                )
+
+                image_inputs = {
+                    "pixel_values_flat": pixel_values_flat,
+                    "image_num_patches": image_num_patches,
+                }
+
+                for i in range(len(images)):
+                    num_patches = image_num_patches[i].item()
+                    feature_size = num_patches * self.num_image_token
+
+                    image_repl = self.get_image_repl(feature_size,
+                                                     num_patches)
+                    text = [
+                        t.replace('<image>', image_repl.full, 1) for t in text
+                    ]
+
+            else:
+                pixel_values_lst = self._images_to_pixel_values_lst(
+                    images,
+                    min_dynamic_patch=min_dynamic_patch,
+                    max_dynamic_patch=max_dynamic_patch,
+                    dynamic_image_size=dynamic_image_size,
+                )
+                image_inputs: dict[str, NestedTensors] = {
+                    "pixel_values_flat":
+                    torch.cat(pixel_values_lst),
+                    "image_num_patches":
+                    torch.tensor([len(item) for item in pixel_values_lst]),
+                }
+
+                for pixel_values in pixel_values_lst:
+                    num_patches = pixel_values.shape[0]
+                    feature_size = num_patches * self.num_image_token
+
+                    image_repl = self.get_image_repl(feature_size,
+                                                     num_patches)
+                    text = [
+                        t.replace('<image>', image_repl.full, 1) for t in text
+                    ]

-                image_repl = self.get_image_repl(feature_size, num_patches)
-                text = [t.replace('<image>', image_repl.full, 1) for t in text]
        return text, image_inputs

    def _make_batch_input(self,
@@ -483,7 +767,7 @@ def _make_batch_input(self,
    def __call__(
        self,
        text: Optional[Union[str, list[str]]] = None,
-        images: Optional[Union[Image.Image, list[Image.Image]]] = None,
+        images: Optional[Union[Image.Image, list[Image.Image], bytes,
+                               list[bytes]]] = None,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,
        dynamic_image_size: Optional[bool] = None,
@@ -602,7 +886,7 @@ def _preprocess_video(
    def __call__(
        self,
        text: Optional[Union[str, list[str]]] = None,
-        images: Optional[Union[Image.Image, list[Image.Image]]] = None,
+        images: Optional[Union[Image.Image, list[Image.Image], bytes,
+                               list[bytes]]] = None,
        videos: Optional[Union[npt.NDArray, list[npt.NDArray]]] = None,
        min_dynamic_patch: Optional[int] = None,
        max_dynamic_patch: Optional[int] = None,