
Commit 906008d

[OpenVINO] Add cache_position input inside prepare_inputs method for Mamba (#1517)

Parent: 916da6a

File tree

1 file changed (+21 −3 lines)


optimum/intel/openvino/modeling_decoder.py

Lines changed: 21 additions & 3 deletions
@@ -1295,15 +1295,20 @@ def _has_cache_inputs(model: openvino.Model) -> bool:
             "past_key_values" in key.get_any_name() or "cache_params" in key.get_any_name() for key in model.inputs
         )
 
-    def forward(
+    def prepare_inputs(
         self,
-        input_ids: Optional[torch.LongTensor] = None,
+        input_ids: torch.LongTensor,
         attention_mask: Optional[torch.LongTensor] = None,
         cache_params=None,
         use_cache: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
         **kwargs,
-    ):
+    ) -> Dict:
+        if kwargs.get("past_key_values") is not None:
+            raise ValueError("`past_key_values` input is not supported for `OVModelWithMambaForCausalLM`")
+        if kwargs.get("position_ids") is not None:
+            raise ValueError("`position_ids` input is not supported for `OVModelWithMambaForCausalLM`")
+
         inputs = {"input_ids": input_ids}
         if "cache_position" in self.input_names:
             if cache_position is None:
@@ -1340,6 +1345,19 @@ def forward(
         batch_size = input_ids.shape[0]
         inputs["beam_idx"] = np.arange(batch_size, dtype=int)
 
+        return inputs
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.LongTensor] = None,
+        cache_params=None,
+        use_cache: Optional[bool] = None,
+        cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        inputs = self.prepare_inputs(input_ids, attention_mask, cache_params, use_cache, cache_position, **kwargs)
+
         self.request.start_async(inputs, share_inputs=True)
         self.request.wait()
         logits = torch.from_numpy(self.request.get_tensor("logits").data).to(self.device)
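The refactor moves all input assembly (including the `cache_position` handling named in the commit title and the `beam_idx` bookkeeping) out of `forward`, so it can be reused and tested on its own; `forward` becomes a thin wrapper around `prepare_inputs` plus the inference request. The explicit `ValueError`s guard against transformer-style arguments: Mamba is a state-space model, so it carries `cache_params` rather than `past_key_values`, and it takes no `position_ids`. Below is a minimal self-contained sketch of the same split, with a toy class and stubbed-out inference in place of the real OpenVINO request object; the class name, `input_names` contents, and prefill default for `cache_position` are illustrative assumptions, not the optimum-intel implementation.

    from typing import Dict, Optional

    import numpy as np
    import torch


    class TinyMambaLike:
        """Toy model illustrating the prepare_inputs/forward split; not the real class."""

        # Hypothetical set of compiled-model input names.
        input_names = {"input_ids", "cache_position"}

        def prepare_inputs(
            self,
            input_ids: torch.LongTensor,
            cache_params=None,
            cache_position: Optional[torch.Tensor] = None,
            **kwargs,
        ) -> Dict:
            # Mirror the commit's guard: a Mamba model has no KV cache input.
            if kwargs.get("past_key_values") is not None:
                raise ValueError("`past_key_values` input is not supported")
            inputs = {"input_ids": input_ids}
            if "cache_position" in self.input_names:
                if cache_position is None:
                    # Prefill default: one position per prompt token (an assumption here).
                    cache_position = torch.arange(input_ids.shape[1])
                inputs["cache_position"] = cache_position
            return inputs

        def forward(self, input_ids: torch.LongTensor, **kwargs):
            # forward is now just: prepare the inputs, then run inference (stubbed).
            inputs = self.prepare_inputs(input_ids, **kwargs)
            return {name: np.asarray(value) for name, value in inputs.items()}


    model = TinyMambaLike()
    outputs = model.forward(torch.tensor([[101, 7592, 102]]))
    print(sorted(outputs))  # ['cache_position', 'input_ids']

One practical benefit of this shape is that input preparation can be exercised in unit tests without an inference device, which is presumably why the guard clauses live in `prepare_inputs` rather than `forward`.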
