
Commit cb2f2ec

eaidova and echarlaix authored
Update llama and gemma patching for resolving bf16 execution issues with sin/cos (#783)
* update llama and gemma patching for resolving bf16 execution issues
* fix model patcher
* update tests
* Update tests/openvino/test_modeling.py (Co-authored-by: Ella Charlaix <[email protected]>)
* Update optimum/exporters/openvino/model_patcher.py (Co-authored-by: Ella Charlaix <[email protected]>)
* apply review comments
* Update optimum/exporters/openvino/model_patcher.py
* format
* Update optimum/exporters/openvino/model_patcher.py
* fix failing test
* Update tests/openvino/test_exporters_cli.py (Co-authored-by: Ella Charlaix <[email protected]>)

---------

Co-authored-by: Ella Charlaix <[email protected]>
1 parent 629a2e3 commit cb2f2ec

File tree

5 files changed, +56 -33 lines changed


optimum/exporters/openvino/model_configs.py

Lines changed: 1 addition & 2 deletions
@@ -49,7 +49,6 @@
     ChatGLMModelPatcher,
     CodeGenModelPatcher,
     DBRXModelPatcher,
-    GemmaModelPatcher,
     InternLM2Patcher,
     InternLMModelPatcher,
     JaisModelPatcher,
@@ -319,7 +318,7 @@ class GemmaOpenVINOConfig(GemmaOnnxConfig):
     def patch_model_for_export(
         self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
     ) -> "ModelPatcher":
-        return GemmaModelPatcher(self, model, model_kwargs=model_kwargs)
+        return LlamaModelPatcher(self, model, model_kwargs=model_kwargs)
 
 
 @register_in_tasks_manager(
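
Since the causal-mask and rotary-embedding workarounds are identical for the two architectures, the Gemma OpenVINO config now simply reuses LlamaModelPatcher. Roughly how the patcher gets applied during export, as a hedged sketch (the surrounding export call and variable names are assumptions, not part of this diff):

# Sketch only: `ov_config` stands for a GemmaOpenVINOConfig instance and `model` for the
# loaded transformers model; both names are placeholders for objects the exporter already holds.
patcher = ov_config.patch_model_for_export(model)   # now a LlamaModelPatcher for Gemma as well
with patcher:                                        # __enter__ swaps in the bf16-safe causal mask and sin/cos
    pass                                             # tracing/export would run here
# __exit__ restores the original _update_causal_mask and rotary forward methods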

optimum/exporters/openvino/model_patcher.py

Lines changed: 39 additions & 24 deletions
@@ -497,50 +497,65 @@ def _llama_gemma_update_causal_mask_latest(
 _llama_gemma_update_causal_mask = _llama_gemma_update_causal_mask_legacy
 
 
-class GemmaModelPatcher(DecoderModelPatcher):
+def llama_gemma_rotary_emb_forward(self, x, position_ids, seq_len=None):
+    # adopted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/gemma/modeling_gemma.py#L104
+    _seq_len = torch.max(position_ids) + 1 if seq_len is None else seq_len
+    if _seq_len > self.embed_positions.shape[0]:
+        if seq_len is None:
+            return self._orig_forward(x, position_ids)
+        else:
+            return self._orig_forward(x, position_ids, seq_len)
+    sincos = self.embed_positions[position_ids]
+    sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
+    return cos, sin
+
+
+class LlamaModelPatcher(DecoderModelPatcher):
     def __enter__(self):
         super().__enter__()
 
-        # gemma has some accuracy issues with bf16 with transformers >= 4.39
+        # llama/gemma has some accuracy issues with bf16 with transformers >= 4.39
         # fill causal mask in slightly different way for avoid overflow on some platforms
         if is_transformers_version(">=", "4.39.0"):
             self._model.model._orig_update_causal_mask = self._model.model._update_causal_mask
             self._model.model._update_causal_mask = types.MethodType(
                 _llama_gemma_update_causal_mask, self._model.model
             )
 
-        # init inv_freq for torchscript tracing
-        # https://github.com/huggingface/transformers/blob/ed74d97871468f3a4695ede50abdc0b55717a84d/src/transformers/models/gemma/modeling_gemma.py#L108
-        for layer in self._model.model.layers:
-            if layer.self_attn.rotary_emb.inv_freq is None:
-                rotary_emb = layer.self_attn.rotary_emb
-                layer.self_attn.rotary_emb.inv_freq = 1.0 / (
-                    rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
-                )
+        max_positions = self._model.config.max_position_embeddings
 
-    def __exit__(self, exc_type, exc_value, traceback):
-        super().__exit__(exc_type, exc_value, traceback)
-        if hasattr(self._model.model, "_orig_update_causal_mask"):
-            self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask
+        # cos/sin for rotary position embeddings also having issues with bf16 and efficiency due to calculation on each step
+        # use precomputed
+        def create_sinusoidal_positions(num_pos: int, dim: int, base: int = 10000) -> torch.Tensor:
+            # adopted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L101
+            inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
 
+            sinusoid_inp = torch.einsum(
+                "i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq
+            ).float()
+            emb = torch.cat((sinusoid_inp, sinusoid_inp), dim=-1)
+            return torch.cat((torch.sin(emb), torch.cos(emb)), dim=1)
 
-class LlamaModelPatcher(DecoderModelPatcher):
-    def __enter__(self):
-        super().__enter__()
+        base = self._model.model.layers[0].self_attn.rotary_emb.base
+        dim = self._model.model.layers[0].self_attn.rotary_emb.dim
+        embed_positions = create_sinusoidal_positions(max_positions, dim, base)
 
-        # llama has some accuracy issues with bf16 with transformers >= 4.39
-        # fill causal mask in slightly different way for avoid overflow on some platforms
-        if is_transformers_version(">=", "4.39.0"):
-            self._model.model._orig_update_causal_mask = self._model.model._update_causal_mask
-            self._model.model._update_causal_mask = types.MethodType(
-                _llama_gemma_update_causal_mask, self._model.model
-            )
+        for layer in self._model.model.layers:
+            layer.self_attn.rotary_emb.register_buffer("embed_positions", embed_positions)
+            layer.self_attn.rotary_emb._orig_forward = layer.self_attn.rotary_emb.forward
+
+            layer.self_attn.rotary_emb.forward = types.MethodType(
+                llama_gemma_rotary_emb_forward, layer.self_attn.rotary_emb
+            )
 
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
         if hasattr(self._model.model, "_orig_update_causal_mask"):
             self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask
 
+        for layer in self._model.model.layers:
+            layer.self_attn.rotary_emb.forward = layer.self_attn.rotary_emb._orig_forward
+
 
 SUPPORT_SDPA = is_torch_version(">", "2.1.0")
 
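The net effect: the sin/cos tables are built once in float32 and stored on each rotary embedding as an embed_positions buffer, so the traced graph only performs a gather and a split instead of recomputing sin/cos on every step (which is both slower and lossy under bf16). A small standalone check of that equivalence, with illustrative sizes rather than a real model config:

import torch

# Build the table exactly as create_sinusoidal_positions does above.
dim, max_pos, base = 64, 128, 10000  # illustrative sizes, not taken from any real checkpoint
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(max_pos, dtype=torch.int64).float(), inv_freq).float()
emb = torch.cat((sinusoid_inp, sinusoid_inp), dim=-1)
table = torch.cat((torch.sin(emb), torch.cos(emb)), dim=1)  # shape [max_pos, 2 * dim]

# Reference: the cos/sin a rotary embedding would compute on the fly in float32.
position_ids = torch.arange(16).unsqueeze(0)                # [1, 16]
freqs = position_ids[..., None].float() * inv_freq          # [1, 16, dim // 2]
ref = torch.cat((freqs, freqs), dim=-1)
cos_ref, sin_ref = ref.cos(), ref.sin()

# Patched path: one gather plus one split, no trig ops at trace/run time.
sincos = table[position_ids]
sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
assert torch.allclose(cos, cos_ref, atol=1e-6) and torch.allclose(sin, sin_ref, atol=1e-6)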

tests/openvino/test_exporters_cli.py

Lines changed: 10 additions & 4 deletions
@@ -41,7 +41,7 @@
 )
 from optimum.intel.openvino.configuration import _DEFAULT_4BIT_CONFIGS
 from optimum.intel.openvino.utils import _HEAD_TO_AUTOMODELS
-from optimum.intel.utils.import_utils import is_openvino_tokenizers_available
+from optimum.intel.utils.import_utils import is_openvino_tokenizers_available, is_transformers_version
 
 
 class OVCLIExportTestCase(unittest.TestCase):
@@ -90,20 +90,26 @@ class OVCLIExportTestCase(unittest.TestCase):
         ("text-generation-with-past", "opt125m", "int4_asym_g128", 4, 144),
         ("text-generation-with-past", "opt125m", "int4_sym_g64", 4, 144),
         ("text-generation-with-past", "opt125m", "int4_asym_g64", 4, 144),
-        ("text-generation-with-past", "llama_awq", "int4 --ratio 1.0 --sym --group-size 16 --all-layers", 0, 32),
+        (
+            "text-generation-with-past",
+            "llama_awq",
+            "int4 --ratio 1.0 --sym --group-size 8 --all-layers",
+            0,
+            32 if is_transformers_version("<", "4.39.0") else 34,
+        ),
         (
             "text-generation-with-past",
            "llama_awq",
             "int4 --ratio 1.0 --sym --group-size 16 --awq --dataset wikitext2 --num-samples 100 "
             "--sensitivity-metric max_activation_variance",
-            4,
+            6 if is_transformers_version(">=", "4.39") else 4,
             28,
         ),
         (
             "text-generation-with-past",
             "llama_awq",
             "int4 --ratio 1.0 --sym --group-size 16 --scale-estimation --dataset wikitext2 --num-samples 100 ",
-            4,
+            6 if is_transformers_version(">=", "4.39") else 4,
             28,
         ),
     ]
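
Each option string above is what the test forwards to optimum-cli as the weight-format option; the expected node counts now depend on the transformers version, presumably because the new patcher exports the precomputed sin/cos tables as extra constants that also get compressed. A rough sketch of the equivalent CLI call (the checkpoint id and output directory are placeholders, and how the test actually shells out is an assumption):

import subprocess

# Hedged sketch: export a tiny llama checkpoint with the first updated option set.
cmd = (
    "optimum-cli export openvino --model <tiny-llama-awq-checkpoint> "
    "--task text-generation-with-past --weight-format int4 "
    "--ratio 1.0 --sym --group-size 8 --all-layers llama_int4_ov"
)
subprocess.run(cmd, shell=True, check=True)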

tests/openvino/test_modeling.py

Lines changed: 4 additions & 1 deletion
@@ -932,15 +932,18 @@ def test_beam_search(self, model_arch):
             do_sample=False,
             eos_token_id=None,
         )
+
         beam_sample_gen_config = GenerationConfig(
             max_new_tokens=10,
             min_new_tokens=10,
             num_beams=4,
             do_sample=True,
             eos_token_id=None,
-            top_k=1,
         )
 
+        if model_arch == "minicpm":
+            beam_sample_gen_config.top_k = 1
+
         group_beam_search_gen_config = GenerationConfig(
             max_new_tokens=10,
             min_new_tokens=10,

tests/openvino/test_quantization.py

Lines changed: 2 additions & 2 deletions
@@ -230,7 +230,7 @@ class OVWeightCompressionTest(unittest.TestCase):
                 quant_method=QuantizationMethod.AWQ,
                 scale_estimation=True,
             ),
-            16,
+            18 if is_transformers_version(">=", "4.39") else 16,
         ),
         (
             OVModelForCausalLM,
@@ -244,7 +244,7 @@
                 dataset="c4",
                 quant_method="awq",
             ),
-            16,
+            18 if is_transformers_version(">=", "4.39") else 16,
         ),
     )
