@@ -1295,15 +1295,20 @@ def _has_cache_inputs(model: openvino.Model) -> bool:
12951295 "past_key_values" in key .get_any_name () or "cache_params" in key .get_any_name () for key in model .inputs
12961296 )
12971297
1298- def forward (
1298+ def prepare_inputs (
12991299 self ,
1300- input_ids : Optional [ torch .LongTensor ] = None ,
1300+ input_ids : torch .LongTensor ,
13011301 attention_mask : Optional [torch .LongTensor ] = None ,
13021302 cache_params = None ,
13031303 use_cache : Optional [bool ] = None ,
13041304 cache_position : Optional [torch .Tensor ] = None ,
13051305 ** kwargs ,
1306- ):
1306+ ) -> Dict :
1307+ if kwargs .get ("past_key_values" ) is not None :
1308+ raise ValueError ("`past_key_values` input is not supported for `OVModelWithMambaForCausalLM`" )
1309+ if kwargs .get ("position_ids" ) is not None :
1310+ raise ValueError ("`position_ids` input is not supported for `OVModelWithMambaForCausalLM`" )
1311+
13071312 inputs = {"input_ids" : input_ids }
13081313 if "cache_position" in self .input_names :
13091314 if cache_position is None :
@@ -1340,6 +1345,19 @@ def forward(
13401345 batch_size = input_ids .shape [0 ]
13411346 inputs ["beam_idx" ] = np .arange (batch_size , dtype = int )
13421347
1348+ return inputs
1349+
1350+ def forward (
1351+ self ,
1352+ input_ids : Optional [torch .LongTensor ] = None ,
1353+ attention_mask : Optional [torch .LongTensor ] = None ,
1354+ cache_params = None ,
1355+ use_cache : Optional [bool ] = None ,
1356+ cache_position : Optional [torch .Tensor ] = None ,
1357+ ** kwargs ,
1358+ ):
1359+ inputs = self .prepare_inputs (input_ids , attention_mask , cache_params , use_cache , cache_position , ** kwargs )
1360+
13431361 self .request .start_async (inputs , share_inputs = True )
13441362 self .request .wait ()
13451363 logits = torch .from_numpy (self .request .get_tensor ("logits" ).data ).to (self .device )
0 commit comments