diff --git a/optimum/intel/openvino/modeling_decoder.py b/optimum/intel/openvino/modeling_decoder.py
index 1dd6434641..33e9a31690 100644
--- a/optimum/intel/openvino/modeling_decoder.py
+++ b/optimum/intel/openvino/modeling_decoder.py
@@ -1295,15 +1295,20 @@ def _has_cache_inputs(model: openvino.Model) -> bool:
             "past_key_values" in key.get_any_name() or "cache_params" in key.get_any_name() for key in model.inputs
         )
 
-    def forward(
+    def prepare_inputs(
         self,
-        input_ids: Optional[torch.LongTensor] = None,
+        input_ids: torch.LongTensor,
         attention_mask: Optional[torch.LongTensor] = None,
         cache_params=None,
         use_cache: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
         **kwargs,
-    ):
+    ) -> Dict:
+        if kwargs.get("past_key_values") is not None:
+            raise ValueError("`past_key_values` input is not supported for `OVModelWithMambaForCausalLM`")
+        if kwargs.get("position_ids") is not None:
+            raise ValueError("`position_ids` input is not supported for `OVModelWithMambaForCausalLM`")
+
         inputs = {"input_ids": input_ids}
         if "cache_position" in self.input_names:
             if cache_position is None:
@@ -1340,6 +1345,19 @@ def forward(
             batch_size = input_ids.shape[0]
             inputs["beam_idx"] = np.arange(batch_size, dtype=int)
 
+        return inputs
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        cache_params=None,
+        use_cache: Optional[bool] = None,
+        cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        inputs = self.prepare_inputs(input_ids, attention_mask, cache_params, use_cache, cache_position, **kwargs)
+
         self.request.start_async(inputs, share_inputs=True)
         self.request.wait()
         logits = torch.from_numpy(self.request.get_tensor("logits").data).to(self.device)
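
For context, a minimal usage sketch of the refactor (not part of the diff): `prepare_inputs` now builds the dictionary that `forward` feeds to the OpenVINO infer request, so the input packing can be inspected or reused without running inference, while `forward` keeps its original behavior by delegating to it. The checkpoint id below is a hypothetical placeholder, and importing the class from `optimum.intel.openvino.modeling_decoder` is an assumption based only on the file this diff touches.

```python
# Sketch only, assuming OVModelWithMambaForCausalLM is importable from the module
# touched by this diff and that the checkpoint id is a valid Mamba model (placeholder).
from transformers import AutoTokenizer
from optimum.intel.openvino.modeling_decoder import OVModelWithMambaForCausalLM

model_id = "state-spaces/mamba-130m-hf"  # hypothetical placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = OVModelWithMambaForCausalLM.from_pretrained(model_id, export=True)

encoded = tokenizer("Hello, world", return_tensors="pt")

# prepare_inputs packs input_ids (plus cache_position, beam_idx, ... when the model
# declares them as inputs) without starting inference, so the packing is easy to inspect.
inputs = model.prepare_inputs(encoded["input_ids"], attention_mask=encoded["attention_mask"])
print(sorted(inputs.keys()))

# forward is unchanged from the caller's point of view: it calls prepare_inputs and then
# runs the async infer request to produce logits.
outputs = model(input_ids=encoded["input_ids"], attention_mask=encoded["attention_mask"])

# The new validation rejects unsupported kwargs explicitly, e.g. passing
# past_key_values=... to prepare_inputs now raises a ValueError.
```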