@@ -230,7 +230,15 @@ def forward(self, image_feature, pos_embed, key_padding_mask):
         return result


-MODEL_PARTS_CLS_MAPPING = {"resampler": OVResampler}
+class OVVisionProjection(OVModelPart):
+    _model_name = "vision_projection"
+
+    def forward(self, img_features):
+        self._compile()
+        return self.request(img_features)[0]
+
+
+MODEL_PARTS_CLS_MAPPING = {"resampler": OVResampler, "vision_projection": OVVisionProjection}


 class OVModelForVisualCausalLM(OVBaseModel, GenerationMixin):
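The new part follows the same lazy compile-then-infer pattern as OVResampler: _compile() builds the inference request on first use, and forward returns the first output tensor. A self-contained sketch of that pattern (the class below is illustrative, not the actual optimum-intel base class):

import openvino as ov

class LazyModelPart:
    """Illustrative stand-in for OVModelPart's compile-on-first-use behavior."""

    def __init__(self, model: ov.Model, device: str = "CPU"):
        self.model = model
        self.device = device
        self.request = None  # created lazily, as in OVVisionProjection._compile()

    def _compile(self):
        if self.request is None:
            self.request = ov.Core().compile_model(self.model, self.device)

    def forward(self, x):
        self._compile()
        return self.request(x)[0]  # first output, mirroring OVVisionProjection.forward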
@@ -1802,8 +1810,8 @@ def preprocess_inputs(
             raise ValueError("Tokenizer is required.")
         if image is not None and processor is None:
             raise ValueError("Processor is required.")
-        text_content = f"<image>\n{text}" if image is not None else text
-        messages = [{"role": "user", "content": text_content}]
+        text = f"<image>\n{text}" if image is not None else text
+        messages = [{"role": "user", "content": text}]
         if tokenizer.chat_template is not None:
             text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
         if image is not None:
@@ -1818,10 +1826,161 @@ def preprocess_inputs(
         return result


+class _OVPhi3VisionForCausalLM(OVModelForVisualCausalLM):
+    additional_parts = ["vision_projection"]
+
+    def __init__(
+        self,
+        language_model: ov.Model,
+        text_embeddings: ov.Model,
+        vision_embeddings: ov.Model,
+        config: PretrainedConfig = None,
+        device: str = "CPU",
+        dynamic_shapes: bool = True,
+        ov_config: Optional[Dict[str, str]] = None,
+        model_save_dir: Optional[Union[str, Path, TemporaryDirectory]] = None,
+        quantization_config: Union[OVWeightQuantizationConfig, Dict] = None,
+        **kwargs,
+    ):
+        super().__init__(
+            language_model,
+            text_embeddings,
+            vision_embeddings,
+            config,
+            device,
+            dynamic_shapes,
+            ov_config,
+            model_save_dir,
+            quantization_config,
+            **kwargs,
+        )
+        self.sub_GN = torch.tensor(self.config.sub_GN)
+        self.glb_GN = torch.tensor(self.config.glb_GN)
+
+    def get_vision_embeddings(self, pixel_values, image_sizes, **kwargs):
+        num_images, num_crops, c, h, w = pixel_values.shape
+        img_features = self.vision_embeddings(pixel_values.flatten(0, 1)).last_hidden_state.reshape(
+            num_images, num_crops, -1, self.config.img_processor["image_dim_out"]
+        )
+        image_features_proj = self.hd_feature_transform(img_features, image_sizes)
+        return image_features_proj
+
+    def hd_feature_transform(self, image_features, image_sizes):
+        """
+        image_features: (num_images, num_crops+1, 24*24, 1024)
+        """
+
+        image_features = torch.from_numpy(image_features)
+        global_image_features = image_features[:, 0]  # (num_images, 24*24, 1024)
+        # global feature can be viewed as a special HD case with num_crops 1x1
+        global_image_features_hd = self.reshape_hd_patches_2x2merge(global_image_features, 1, 1)
+        global_image_features_hd_newline = self.add_image_newline(global_image_features_hd)
+
+        all_image_embeddings = []
+        # need a for loop to process each image because of different image sizes
+        # (patch arrangement is different for each image)
+        for i, img_size in enumerate(image_sizes):
+            h, w = img_size
+            h_crop = h // 336
+            w_crop = w // 336
+            num_crops = h_crop * w_crop
+
+            # NOTE: real num_crops is padded
+            # (num_crops, 24*24, 1024)
+            sub_image_features = image_features[i, 1 : 1 + num_crops]
+            sub_image_features_hd = self.reshape_hd_patches_2x2merge(sub_image_features, h_crop, w_crop)
+            sub_image_features_hd_newline = self.add_image_newline(sub_image_features_hd)
+
+            # [sub features, separator, global features]
+            all_image_embeddings.extend(
+                [
+                    sub_image_features_hd_newline.squeeze(0),  # (h_crop*12*(w_crop*12+1), 4096)
+                    self.glb_GN.squeeze(0),
+                    global_image_features_hd_newline[i],
+                ]
+            )
+        image_features_proj = self.vision_projection(torch.cat(all_image_embeddings, dim=0).unsqueeze(0))[0]
+
+        return image_features_proj
+
+    def reshape_hd_patches_2x2merge(self, image_features, h_crop, w_crop):
+        """
+        image_features: (num_images*num_crops, 24*24, 1024)
+        output: (num_images, h_crop*12, w_crop*12, 4096), h_crop*w_crop == num_crops
+        """
+        N, L, C = image_features.shape
+        assert L == 24 * 24 and C == 1024 and N % (h_crop * w_crop) == 0
+        num_images = N // (h_crop * w_crop)
+        H = int(L**0.5)
+        image_features_hd = (
+            image_features.reshape(N, H, H, C)  # N, 24, 24, 1024
+            .reshape(N, H // 2, 2, H // 2, 2, C)  # N, 12, 2, 12, 2, 1024
+            .permute(0, 1, 3, 2, 4, 5)  # N, 12, 12, 2, 2, 1024
+            .reshape(N, -1, 4 * C)  # N, 144, 4096
+            .reshape(num_images, h_crop, w_crop, H // 2, H // 2, -1)  # n_img, h_crop, w_crop, 12, 12, 4096
+            .permute(0, 1, 3, 2, 4, 5)  # n_img, h_crop, 12, w_crop, 12, 4096
+            .reshape(num_images, h_crop * H // 2, w_crop * H // 2, 4 * C)  # n_img, h_crop*12, w_crop*12, 4096
+        )

+        return image_features_hd
+
+    def add_image_newline(self, image_features_hd):
+        """
+        image_features_hd: (num_images, h_crop*12, w_crop*12, 4096)
+        output: (num_images, (h_crop*12) * (w_crop*12+1), 4096)
+        """
+        num_images, h, w, hid_dim = image_features_hd.shape
+        # add the newline token to the HD image feature patches
+        newline_embeddings = self.sub_GN.expand(num_images, h, -1, -1)  # (n_img, h, 1, hid_dim)
+        image_features_hd_newline = torch.cat([image_features_hd, newline_embeddings], dim=2).reshape(
+            num_images, -1, hid_dim
+        )
+        return image_features_hd_newline
+
+    def get_multimodal_embeddings(
+        self, input_ids, pixel_values=None, attention_mask=None, position_ids=None, image_sizes=None, **kwargs
+    ):
+        MAX_INPUT_ID = int(1e9)
+        input_shape = input_ids.size()
+        input_ids = input_ids.view(-1, input_shape[-1])
+
+        # positions for image tokens
+        positions = torch.nonzero((input_ids < 0) & (input_ids > -MAX_INPUT_ID), as_tuple=True)
+        has_image = len(positions[0].tolist()) > 0
+        input_ids = input_ids.clamp_min(0).clamp_max(self.config.vocab_size)
+        inputs_embeds = torch.from_numpy(self.get_text_embeddings(input_ids, **kwargs))
+        if has_image:
+            vision_embeds = self.get_vision_embeddings(
+                pixel_values, input_ids=input_ids, image_sizes=image_sizes, **kwargs
+            )
+            image_features_proj = torch.from_numpy(vision_embeds)
+            inputs_embeds = inputs_embeds.index_put(positions, image_features_proj, accumulate=False)
+
+        return inputs_embeds, attention_mask, position_ids
+
+    @staticmethod
+    def preprocess_inputs(
+        text: str,
+        image: Optional[Image] = None,
+        processor: Optional[AutoImageProcessor] = None,
+        tokenizer: Optional[PreTrainedTokenizer] = None,
+    ):
+        if processor is None:
+            raise ValueError("Processor is required.")
+        if image is not None and "<|image_1|>" not in text:
+            text = "<|image_1|>\n" + text
+        if getattr(processor.tokenizer, "chat_template", None) is not None:
+            chat_prompt = [{"role": "user", "content": text}]
+            text = processor.tokenizer.apply_chat_template(chat_prompt, add_generation_prompt=True, tokenize=False)
+        inputs = processor(images=image, text=text, return_tensors="pt")
+        return inputs
+
+
 MODEL_TYPE_TO_CLS_MAPPING = {
     "llava": _OVLlavaForCausalLM,
     "llava_next": _OVLlavaNextForCausalLM,
     "internvl_chat": _OvInternVLForCausalLM,
     "minicpmv": _OVMiniCPMVForCausalLM,
     "llava-qwen2": _OVNanoLlavaForCausalLM,
+    "phi3_v": _OVPhi3VisionForCausalLM,
 }
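To sanity-check the HD transform above, the 2x2 merge and newline insertion can be exercised on random features; the dimensions follow the docstrings (24x24 patches of width 1024 merge into 12x12 patches of width 4096, with one newline embedding appended per row). A standalone snippet, independent of the model classes:

import torch

# Example crop grid: one image split into 2x3 crops of 336px each.
h_crop, w_crop, num_images = 2, 3, 1
feats = torch.randn(num_images * h_crop * w_crop, 24 * 24, 1024)

N, L, C = feats.shape
H = int(L**0.5)
hd = (
    feats.reshape(N, H, H, C)
    .reshape(N, H // 2, 2, H // 2, 2, C)
    .permute(0, 1, 3, 2, 4, 5)
    .reshape(N, -1, 4 * C)
    .reshape(num_images, h_crop, w_crop, H // 2, H // 2, -1)
    .permute(0, 1, 3, 2, 4, 5)
    .reshape(num_images, h_crop * H // 2, w_crop * H // 2, 4 * C)
)
print(hd.shape)  # torch.Size([1, 24, 36, 4096]) == (num_images, h_crop*12, w_crop*12, 4096)

sub_GN = torch.randn(1, 1, 1, 4096)  # stand-in for self.sub_GN
newline = sub_GN.expand(num_images, hd.shape[1], -1, -1)  # one newline slot per row
with_newline = torch.cat([hd, newline], dim=2).reshape(num_images, -1, 4096)
print(with_newline.shape)  # torch.Size([1, 888, 4096]) == (num_images, 24 * (36 + 1), 4096)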
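get_multimodal_embeddings relies on Phi-3-vision's convention of encoding image placeholders as negative input ids; those positions are located with torch.nonzero and overwritten with projected image features via index_put. A toy illustration of that mechanism, with made-up shapes:

import torch

input_ids = torch.tensor([[1, -1, -1, 42]])  # -1 marks <|image_1|> placeholder slots
MAX_INPUT_ID = int(1e9)
positions = torch.nonzero((input_ids < 0) & (input_ids > -MAX_INPUT_ID), as_tuple=True)

embeds = torch.zeros(1, 4, 8)        # pretend text embeddings, hidden size 8
image_feats = torch.ones(2, 8)       # one projected feature row per image token
embeds = embeds.index_put(positions, image_feats, accumulate=False)
print(embeds[0, 1])  # all ones: the placeholder slot now holds an image feature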
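With "phi3_v" registered in MODEL_TYPE_TO_CLS_MAPPING, a Phi-3-vision checkpoint should load through the generic OVModelForVisualCausalLM entry point. A rough usage sketch; the model id, export flag, and image path are illustrative and may differ across optimum-intel versions:

from PIL import Image
from transformers import AutoProcessor
from optimum.intel import OVModelForVisualCausalLM

model_id = "microsoft/Phi-3-vision-128k-instruct"  # example checkpoint, not pinned by this change
processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
model = OVModelForVisualCausalLM.from_pretrained(model_id, export=True, trust_remote_code=True)

image = Image.open("example.jpg")  # placeholder path
inputs = model.preprocess_inputs("Describe this image.", image=image, processor=processor)
outputs = model.generate(**inputs, max_new_tokens=50)
print(processor.batch_decode(outputs, skip_special_tokens=True)[0])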