
Commit 6b993b8

Fix ipex mistral export for bs > 1 (#1276)
* fix bug when bs > 1 and do not provide `position_ids` for input
  Signed-off-by: Liu, Kaixuan <[email protected]>
* Update tests/ipex/test_modeling_causal_lm.py
  Co-authored-by: Ella Charlaix <[email protected]>
* delete code for test case
  Signed-off-by: Liu, Kaixuan <[email protected]>
---------
Signed-off-by: Liu, Kaixuan <[email protected]>
Co-authored-by: Ella Charlaix <[email protected]>
1 parent 264205b commit 6b993b8

File tree

2 files changed (+34, -27 lines)


optimum/exporters/ipex/modeling_utils.py

Lines changed: 6 additions & 3 deletions
@@ -665,7 +665,10 @@ def _mistral_model_forward(
     )
 
     if position_ids is None:
-        position_ids = cache_position.unsqueeze(0)
+        position_ids = torch.arange(
+            past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+        )
+        position_ids = position_ids.unsqueeze(0).repeat_interleave(input_ids.shape[0], 0)
 
     causal_mask = self._update_causal_mask(
         attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
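For readers following the patch, here is a small standalone sketch (not part of the commit) of what this change does for batched inputs: the old `cache_position.unsqueeze(0)` path yields a single row of position ids regardless of batch size, while the new code builds one row per sequence. The concrete values for `batch_size`, `seq_length`, and `past_key_values_length` below are made up for illustration; in `_mistral_model_forward` they come from the model inputs.

```python
import torch

# Illustrative values only.
batch_size, seq_length, past_key_values_length = 2, 5, 0
input_ids = torch.randint(0, 100, (batch_size, seq_length))
device = input_ids.device

# Old behaviour: one row of position ids, whatever the batch size.
cache_position = torch.arange(past_key_values_length, past_key_values_length + seq_length, device=device)
old_position_ids = cache_position.unsqueeze(0)

# New behaviour: the same range, repeated once per sequence in the batch.
new_position_ids = torch.arange(
    past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
)
new_position_ids = new_position_ids.unsqueeze(0).repeat_interleave(input_ids.shape[0], 0)

print(old_position_ids.shape)  # torch.Size([1, 5])
print(new_position_ids.shape)  # torch.Size([2, 5])
```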
@@ -1021,7 +1024,7 @@ def __init__(self, module, device, config) -> None:
         self.module_device = device
 
         if not config.compile and getattr(config, "quantization_config", None) is None:
-            # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
+            # LinearAllreduce cannot use fused op LinearAdd
             if module.down_proj.__class__.__name__ not in ["LinearAllreduce"]:
                 self.mlp_linear_add = LinearAdd(module.down_proj)
             if isinstance(self.act_fn, nn.SiLU):
@@ -1049,7 +1052,7 @@ def __init__(self, module, device, config) -> None:
         self.config = config
         self.module_device = device
         if not config.compile and getattr(config, "quantization_config", None) is None:
-            # LinearAllreduce and LinearLayer cannot use fused op LinearAdd
+            # LinearAllreduce cannot use fused op LinearAdd
             self.linear_gelu = LinearGelu(module.dense_h_to_4h)
 
             if module.dense_4h_to_h.__class__.__name__ not in ["LinearAllreduce"]:

tests/ipex/test_modeling_causal_lm.py

Lines changed: 28 additions & 24 deletions
@@ -70,35 +70,39 @@ def test_compare_to_transformers(self, model_arch):
         ipex_model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map=DEVICE)
         self.assertIsInstance(ipex_model.config, PretrainedConfig)
         tokenizer = AutoTokenizer.from_pretrained(model_id)
-        tokens = tokenizer(
-            "This is a sample",
-            return_tensors="pt",
-            return_token_type_ids=False if model_arch in ("llama2",) else None,
-        ).to(DEVICE)
-        inputs = ipex_model.prepare_inputs_for_generation(**tokens)
-        outputs = ipex_model(**inputs)
+        texts = ["This is a sample", ["This is the first input", "This is the second input"]]
+        for text in texts:
+            tokens = tokenizer(
+                text,
+                return_tensors="pt",
+                return_token_type_ids=False if model_arch in ("llama2",) else None,
+            ).to(DEVICE)
+            outputs = ipex_model(**tokens)
+            inputs = ipex_model.prepare_inputs_for_generation(**tokens)
+            outputs_2 = ipex_model(**inputs)
+            self.assertTrue(torch.allclose(outputs.logits, outputs_2.logits, atol=1e-3))
 
-        self.assertIsInstance(outputs.logits, torch.Tensor)
+            self.assertIsInstance(outputs.logits, torch.Tensor)
 
-        transformers_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map=DEVICE)
-        with torch.no_grad():
-            transformers_outputs = transformers_model(**tokens)
+            transformers_model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=dtype, device_map=DEVICE)
+            with torch.no_grad():
+                transformers_outputs = transformers_model(**tokens)
 
-        # Test re-load model
-        with tempfile.TemporaryDirectory() as tmpdirname:
-            ipex_model.save_pretrained(tmpdirname)
-            loaded_model = self.IPEX_MODEL_CLASS.from_pretrained(tmpdirname, torch_dtype=dtype, device_map=DEVICE)
-            loaded_model_outputs = loaded_model(**inputs)
+            # Test re-load model
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                ipex_model.save_pretrained(tmpdirname)
+                loaded_model = self.IPEX_MODEL_CLASS.from_pretrained(tmpdirname, torch_dtype=dtype, device_map=DEVICE)
+                loaded_model_outputs = loaded_model(**inputs)
 
-        # Test init method
-        init_model = self.IPEX_MODEL_CLASS(transformers_model)
-        init_model_outputs = init_model(**inputs)
+            # Test init method
+            init_model = self.IPEX_MODEL_CLASS(transformers_model)
+            init_model_outputs = init_model(**inputs)
 
-        # Compare tensor outputs
-        self.assertTrue(torch.allclose(outputs.logits, transformers_outputs.logits, atol=1e-3))
-        # To avoid float pointing error
-        self.assertTrue(torch.allclose(outputs.logits, loaded_model_outputs.logits, atol=1e-7))
-        self.assertTrue(torch.allclose(outputs.logits, init_model_outputs.logits, atol=1e-7))
+            # Compare tensor outputs
+            self.assertTrue(torch.allclose(outputs.logits, transformers_outputs.logits, atol=1e-3))
+            # To avoid float pointing error
+            self.assertTrue(torch.allclose(outputs.logits, loaded_model_outputs.logits, atol=1e-7))
+            self.assertTrue(torch.allclose(outputs.logits, init_model_outputs.logits, atol=1e-7))
 
     @parameterized.expand(SUPPORTED_ARCHITECTURES)
     @unittest.skip(reason="Paged attention do not support assisted decoding for now")
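For reference, a rough usage sketch (not from the repository) of the scenario the updated test now covers: a batch of prompts passed to an IPEX-exported model without explicit `position_ids`. The checkpoint id below is a placeholder; substitute any Mistral-style model you have available.

```python
import torch
from transformers import AutoTokenizer
from optimum.intel import IPEXModelForCausalLM

# Placeholder checkpoint id: substitute a real Mistral-style model.
model_id = "<your-mistral-checkpoint>"

tokenizer = AutoTokenizer.from_pretrained(model_id)
if tokenizer.pad_token is None:
    # Mistral tokenizers usually ship without a pad token; reuse EOS so batched padding works.
    tokenizer.pad_token = tokenizer.eos_token

model = IPEXModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)

# Batch of two prompts, no explicit position_ids: the case fixed by this commit.
prompts = ["This is the first input", "This is the second input"]
tokens = tokenizer(prompts, return_tensors="pt", padding=True)

with torch.no_grad():
    outputs = model(**tokens)

# Logits are produced for every sequence in the batch.
print(outputs.logits.shape)  # (batch_size, sequence_length, vocab_size)
```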
