24 changes: 21 additions & 3 deletions optimum/intel/openvino/modeling_decoder.py
@@ -1295,15 +1295,20 @@ def _has_cache_inputs(model: openvino.Model) -> bool:
         "past_key_values" in key.get_any_name() or "cache_params" in key.get_any_name() for key in model.inputs
     )
 
-    def forward(
+    def prepare_inputs(
         self,
-        input_ids: Optional[torch.LongTensor] = None,
+        input_ids: torch.LongTensor,
         attention_mask: Optional[torch.LongTensor] = None,
         cache_params=None,
         use_cache: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
         **kwargs,
-    ):
+    ) -> Dict:
+        if kwargs.get("past_key_values") is not None:
+            raise ValueError("`past_key_values` input is not supported for `OVModelWithMambaForCausalLM`")
+        if kwargs.get("position_ids") is not None:
+            raise ValueError("`position_ids` input is not supported for `OVModelWithMambaForCausalLM`")
Comment on lines +1307 to +1310

Member:
is this really needed?

Collaborator Author:
My motivation was to avoid possible confusion if anyone passes these arguments, because the parent's definition of prepare_inputs() has these arguments in its signature.
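For context, here is a minimal stand-alone sketch of the failure mode those two checks guard against. The class is a simplified stand-in, not the real OVModelWithMambaForCausalLM, and the token ids are placeholders:

# Simplified stand-in for illustration only -- not the actual optimum-intel class.
class MambaLikeModel:
    def prepare_inputs(self, input_ids, attention_mask=None, cache_params=None,
                       use_cache=None, cache_position=None, **kwargs):
        # Fail fast on arguments the parent signature accepts but this model cannot honour.
        if kwargs.get("past_key_values") is not None:
            raise ValueError("`past_key_values` input is not supported for `OVModelWithMambaForCausalLM`")
        if kwargs.get("position_ids") is not None:
            raise ValueError("`position_ids` input is not supported for `OVModelWithMambaForCausalLM`")
        return {"input_ids": input_ids}

model = MambaLikeModel()
model.prepare_inputs([[101, 2023, 102]])  # accepted
try:
    model.prepare_inputs([[101, 2023, 102]], past_key_values=object())
except ValueError as err:
    print(err)  # explicit error instead of a silently ignored argument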

Member:
hmm, I'm honestly not sure why this function is public; imo it should be private (it serves one purpose, an internal one)

Member (@IlyasMoutawwakil, Nov 13, 2025):
that's still internal though 😅.. what I mean is that it shouldn't be used by a user.
The downside of a public method is that we need to maintain its behaviour and can only change it through a deprecation process (over multiple versions).
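To illustrate the deprecation process mentioned above, here is a hypothetical sketch (the names and the warning policy are assumptions, not part of this PR) of how a public method is typically kept alive while being turned into an internal one:

import warnings

class Example:
    def _prepare_inputs(self, input_ids, **kwargs):
        # internal implementation that the class itself calls
        return {"input_ids": input_ids}

    def prepare_inputs(self, input_ids, **kwargs):
        # public wrapper kept for a few releases so external callers get a warning, not a break
        warnings.warn(
            "`prepare_inputs` is considered internal and may become private in a future release.",
            FutureWarning,
        )
        return self._prepare_inputs(input_ids, **kwargs)

Example().prepare_inputs([[1, 2, 3]])  # emits FutureWarning but still works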


+
         inputs = {"input_ids": input_ids}
         if "cache_position" in self.input_names:
             if cache_position is None:
@@ -1340,6 +1345,19 @@ def forward(
             batch_size = input_ids.shape[0]
             inputs["beam_idx"] = np.arange(batch_size, dtype=int)
 
+        return inputs
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        cache_params=None,
+        use_cache: Optional[bool] = None,
+        cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        inputs = self.prepare_inputs(input_ids, attention_mask, cache_params, use_cache, cache_position, **kwargs)
+
         self.request.start_async(inputs, share_inputs=True)
         self.request.wait()
         logits = torch.from_numpy(self.request.get_tensor("logits").data).to(self.device)
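A hedged end-to-end usage sketch of the refactored path: generate() calls forward(), which now delegates input packing to prepare_inputs(). The import location, the export flag, and the checkpoint id below are assumptions for illustration, not taken from this diff:

from transformers import AutoTokenizer
from optimum.intel import OVModelWithMambaForCausalLM  # import location is an assumption

model_id = "state-spaces/mamba-130m-hf"  # example checkpoint, not referenced in the PR
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = OVModelWithMambaForCausalLM.from_pretrained(model_id, export=True)

prompt = tokenizer("OpenVINO is", return_tensors="pt")
output = model.generate(**prompt, max_new_tokens=20)
print(tokenizer.decode(output[0], skip_special_tokens=True))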