
Commit 6634bbc

Fixes openvino decoder models output (#1308)
* fix inference
* add test
* fix
1 parent 94b2d60 commit 6634bbc
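
The commit message is terse, so a word on the underlying issue (implied by the fix and the new test rather than stated in the original description): `torch.from_numpy` on the infer request's output buffer creates a zero-copy view, so the next inference on the same request silently rewrites logits that were already returned to the caller. A minimal, self-contained sketch of that aliasing behaviour, using a plain NumPy array as a stand-in for the request's output memory:

import numpy as np
import torch

# Stand-in for the infer request's "logits" output buffer, which is reused
# (and overwritten in place) by every subsequent inference.
output_buffer = np.array([[0.1, 0.2, 0.3]], dtype=np.float32)

logits_view = torch.from_numpy(output_buffer)           # zero-copy view of the buffer
logits_copy = torch.from_numpy(output_buffer).clone()   # independent snapshot

# A second inference rewrites the same buffer in place.
output_buffer[:] = [[9.0, 9.0, 9.0]]

print(logits_view)  # tensor([[9., 9., 9.]]) -> previously returned logits changed
print(logits_copy)  # tensor([[0.1000, 0.2000, 0.3000]]) -> snapshot is unaffected

The patch applies the same remedy in three modeling files: `.clone()` on the returned logits tensors and `np.copy` on the non-stateful past key/values, so the outputs own their memory.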

File tree: 4 files changed, +24 -9 lines

optimum/intel/openvino/modeling_decoder.py

Lines changed: 5 additions & 2 deletions
@@ -553,10 +553,11 @@ def forward(
 
         if self._first_iter_beam_search:
             inputs, duplication_indices = self._deduplicate_inputs(inputs)
+
         # Run inference
         self.request.start_async(inputs, share_inputs=True)
         self.request.wait()
-        logits = torch.from_numpy(self.request.get_tensor("logits").data).to(self.device)
+        logits = torch.from_numpy(self.request.get_tensor("logits").data).clone().to(self.device)
         if self.stateful:
             # Need a marker to differentiate the first generate iteration from the others in
             # the first condition at the function beginning above.

@@ -567,7 +568,9 @@ def forward(
         if not self.stateful:
             if self.use_cache:
                 # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the self-attention layer)
-                past_key_values = tuple(self.request.get_tensor(key).data for key in self.key_value_output_names)
+                past_key_values = tuple(
+                    np.copy(self.request.get_tensor(key).data) for key in self.key_value_output_names
+                )
                 if self.config.model_type not in MULTI_QUERY_ATTN_MODELS or (
                     self.config.model_type == "falcon" and self.config.new_decoder_architecture
                 ):

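A hedged note on the second half of this hunk: without `np.copy`, the `past_key_values` tuple would hold views into the request's key/value output tensors rather than snapshots, so the cached values could change under the caller on the next inference. The NumPy array below is only a stand-in for one such output tensor:

import numpy as np

# Stand-in for one key/value output tensor of the infer request.
kv_buffer = np.zeros((1, 2, 4), dtype=np.float32)

kv_view = kv_buffer               # what keeping `.data` directly amounts to
kv_snapshot = np.copy(kv_buffer)  # what the patched code stores

kv_buffer += 1.0                  # the next inference rewrites the buffer in place

print(kv_view[0, 0, 0])      # 1.0 -> the "cached" value silently changed
print(kv_snapshot[0, 0, 0])  # 0.0 -> the copy keeps the original value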
optimum/intel/openvino/modeling_seq2seq.py

Lines changed: 4 additions & 2 deletions
@@ -985,15 +985,17 @@ def forward(
         # Run inference
         self.request.start_async(inputs, share_inputs=True)
         self.request.wait()
-        logits = torch.from_numpy(self.request.get_tensor("logits").data).to(self.device)
+        logits = torch.from_numpy(self.request.get_tensor("logits").data).clone().to(self.device)
         self._past_length += input_ids.shape[1]

         out_past_key_values = ((),)

         if not self.stateful:
             # Tuple of length equal to : number of layer * number of past_key_value per decoder layer (2 corresponds to the
             # self-attention layer and 2 to the cross-attention layer)
-            out_past_key_values = tuple(self.request.get_tensor(key).data for key in self.key_value_output_names)
+            out_past_key_values = tuple(
+                np.copy(self.request.get_tensor(key).data) for key in self.key_value_output_names
+            )

             # Tuple of tuple of length `n_layers`, with each tuple of length equal to:
             # * 4 for the decoder without cache (k/v of self-attention + k/v of cross-attention)

optimum/intel/openvino/modeling_visual_language.py

Lines changed: 1 addition & 1 deletion
@@ -223,7 +223,7 @@ def forward(
         self.request.start_async(inputs, share_inputs=True)
         self.request.wait()
         logits = self.request.get_tensor("logits").data
-        logits = torch.from_numpy(logits).to(self.device)
+        logits = torch.from_numpy(logits).clone().to(self.device)
         past_key_values = ((),)
         self._past_length += inputs["inputs_embeds"].shape[1]

tests/openvino/test_modeling.py

Lines changed: 14 additions & 4 deletions
@@ -1450,25 +1450,35 @@ def test_compare_with_and_without_past_key_values(self):
         model_id = MODEL_NAMES["gpt2"]
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         tokens = tokenizer("This is a sample input", return_tensors="pt")
+
         model_with_pkv = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True, stateful=False)
         outputs_model_with_pkv = model_with_pkv.generate(
             **tokens, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1
         )
+        del model_with_pkv
+
         model_without_pkv = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=False)
         outputs_model_without_pkv = model_without_pkv.generate(
             **tokens, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1
         )
+        del model_without_pkv
+
         self.assertTrue(torch.equal(outputs_model_with_pkv, outputs_model_without_pkv))
         self.assertEqual(outputs_model_with_pkv.shape[1], self.GENERATION_LENGTH)
         self.assertEqual(outputs_model_without_pkv.shape[1], self.GENERATION_LENGTH)
+
         model_stateful = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True, stateful=True)
         outputs_model_stateful = model_stateful.generate(
             **tokens, min_length=self.GENERATION_LENGTH, max_length=self.GENERATION_LENGTH, num_beams=1
         )
         self.assertTrue(torch.equal(outputs_model_without_pkv, outputs_model_stateful))

-        del model_with_pkv
-        del model_without_pkv
+        logits = model_stateful(**tokens).logits
+        copy_logits = copy.deepcopy(logits)
+        tokens = tokenizer("Input sample", return_tensors="pt")
+        model_stateful(**tokens).logits
+        self.assertTrue(torch.equal(copy_logits, logits))
+        del model_stateful
         gc.collect()

     def test_print_model_properties(self):

@@ -1496,7 +1506,7 @@ def test_auto_device_loading(self):

     def test_default_filling_attention_mask(self):
         model_id = MODEL_NAMES["gpt2"]
-        model_with_cache = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True)
+        model_with_cache = OVModelForCausalLM.from_pretrained(model_id, stateful=False, use_cache=True)
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         tokenizer.pad_token = tokenizer.eos_token
         texts = ["this is a simple input"]

@@ -1519,7 +1529,7 @@ def test_default_filling_attention_mask(self):

     def test_default_filling_attention_mask_and_position_ids(self):
         model_id = MODEL_NAMES["llama"]
-        model_with_cache = OVModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True)
+        model_with_cache = OVModelForCausalLM.from_pretrained(model_id, stateful=False, use_cache=True)
         tokenizer = AutoTokenizer.from_pretrained(model_id)
         tokenizer.pad_token = tokenizer.eos_token
         texts = ["this is a simple input"]

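For context, a hedged end-to-end sketch of the behaviour the new assertions in test_compare_with_and_without_past_key_values guard against. The checkpoint name "gpt2" is only an illustrative choice here (the test itself resolves a tiny model through MODEL_NAMES["gpt2"]):

import torch
from transformers import AutoTokenizer
from optimum.intel import OVModelForCausalLM

model_id = "gpt2"  # illustrative checkpoint, not prescribed by the commit
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = OVModelForCausalLM.from_pretrained(model_id, export=True)

first_logits = model(**tokenizer("This is a sample input", return_tensors="pt")).logits
snapshot = first_logits.clone()

# Run a second, unrelated inference on the same model.
model(**tokenizer("Input sample", return_tensors="pt"))

# With this commit, the logits returned by the first call are no longer overwritten.
assert torch.equal(first_logits, snapshot)

Something like `pytest tests/openvino/test_modeling.py -k compare_with_and_without_past_key_values` should pick up the updated test locally.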