
Commit 31430d6

eaidova and echarlaix authored
Fix bf16 inference accuracy for mistral, phi3, dbrx (#833)
* Fix bf16 inference accuracy for mistral, phi3, dbrx
* reuse inv_freq
* Apply suggestions from code review
  Co-authored-by: Ella Charlaix <[email protected]>
* make dim and base optional
* fix model patcher for dbrx and add bitwise fix for mistral
---------
Co-authored-by: Ella Charlaix <[email protected]>
1 parent 819c513 commit 31430d6

File tree

2 files changed, +207 −20 lines


optimum/exporters/openvino/model_configs.py

Lines changed: 20 additions & 0 deletions
@@ -24,6 +24,7 @@
     FalconOnnxConfig,
     GemmaOnnxConfig,
     LlamaOnnxConfig,
+    MistralOnnxConfig,
     MPTOnnxConfig,
     PhiOnnxConfig,
     UNetOnnxConfig,
@@ -53,6 +54,7 @@
     InternLMModelPatcher,
     JaisModelPatcher,
     LlamaModelPatcher,
+    MistralModelPatcher,
     MixtralModelPatcher,
     MPTModelPatcher,
     PersimmonModelPatcher,
@@ -839,3 +841,21 @@ def patch_model_for_export(
         )
 
         return ArcticModelPatcher(self, model, model_kwargs=model_kwargs)
+
+
+@register_in_tasks_manager(
+    "mistral",
+    *[
+        "feature-extraction",
+        "feature-extraction-with-past",
+        "text-generation",
+        "text-generation-with-past",
+        "text-classification",
+    ],
+    library_name="transformers",
+)
+class MistralOpenVINOConfig(MistralOnnxConfig):
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return MistralModelPatcher(self, model, model_kwargs=model_kwargs)

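For context, a minimal usage sketch (not part of the commit): exporting a Mistral checkpoint through optimum-intel routes the conversion through the MistralOpenVINOConfig registered above, which installs MistralModelPatcher for the duration of the export. The checkpoint id and generation settings below are illustrative assumptions.

# Illustrative export path, assuming optimum-intel with the OpenVINO extras installed.
from optimum.intel import OVModelForCausalLM
from transformers import AutoTokenizer

model_id = "mistralai/Mistral-7B-Instruct-v0.2"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = OVModelForCausalLM.from_pretrained(model_id, export=True)  # conversion runs with the patcher active

inputs = tokenizer("Hello", return_tensors="pt")
print(tokenizer.decode(model.generate(**inputs, max_new_tokens=8)[0]))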
optimum/exporters/openvino/model_patcher.py

Lines changed: 187 additions & 20 deletions
@@ -510,6 +510,39 @@ def llama_gemma_rotary_emb_forward(self, x, position_ids, seq_len=None):
     return cos, sin
 
 
+def create_sinusoidal_positions(num_pos: int, dim: int, base: int = 10000, inv_freq=None) -> torch.Tensor:
+    # adopted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L101
+    if inv_freq is None:
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
+
+    sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq).float()
+    emb = torch.cat((sinusoid_inp, sinusoid_inp), dim=-1)
+    return torch.cat((torch.sin(emb), torch.cos(emb)), dim=1)
+
+
+def register_sin_cos_buffer(model):
+    max_positions = model.config.max_position_embeddings
+
+    # cos/sin for rotary position embeddings also having issues with bf16 and efficiency due to calculation on each step
+    # use precomputed
+
+    rotary_emb = model.model.layers[0].self_attn.rotary_emb
+    dim, base = None, None
+    inv_freq = getattr(rotary_emb, "inv_freq", None)
+    if inv_freq is None:
+        base = rotary_emb.base
+        dim = rotary_emb.dim
+    embed_positions = create_sinusoidal_positions(max_positions, dim, base, inv_freq)
+
+    for layer in model.model.layers:
+        layer.self_attn.rotary_emb.register_buffer("embed_positions", embed_positions)
+        layer.self_attn.rotary_emb._orig_forward = layer.self_attn.rotary_emb.forward
+
+        layer.self_attn.rotary_emb.forward = types.MethodType(
+            llama_gemma_rotary_emb_forward, layer.self_attn.rotary_emb
+        )
+
+
 class LlamaModelPatcher(DecoderModelPatcher):
     def __enter__(self):
         super().__enter__()
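The helpers above precompute the sin/cos table once in float32 and attach it as a buffer on every layer's rotary embedding. A standalone illustration of the problem this avoids (not optimum code; sizes are arbitrary): rotary angles computed directly in bfloat16 drift noticeably at large positions, whereas indexing a float32-precomputed table does not.

# Standalone illustration: bf16 rotary angles lose precision at large positions.
import torch

dim, base, position = 128, 10000, 4000  # arbitrary sizes for the demo
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))

angles_fp32 = position * inv_freq                                # reference computed in float32
angles_bf16 = (position * inv_freq.to(torch.bfloat16)).float()   # what on-the-fly bf16 inference would see

print((torch.cos(angles_fp32) - torch.cos(angles_bf16)).abs().max())  # clearly non-zero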
@@ -521,39 +554,148 @@ def __enter__(self):
             self._model.model._update_causal_mask = types.MethodType(
                 _llama_gemma_update_causal_mask, self._model.model
             )
+        register_sin_cos_buffer(self._model)
 
-        max_positions = self._model.config.max_position_embeddings
-
-        # cos/sin for rotary position embeddings also having issues with bf16 and efficiency due to calculation on each step
-        # use precomputed
-        def create_sinusoidal_positions(num_pos: int, dim: int, base: int = 10000) -> torch.Tensor:
-            # adopted from https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py#L101
-            inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
-
-            sinusoid_inp = torch.einsum(
-                "i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq
-            ).float()
-            emb = torch.cat((sinusoid_inp, sinusoid_inp), dim=-1)
-            return torch.cat((torch.sin(emb), torch.cos(emb)), dim=1)
-
-        base = self._model.model.layers[0].self_attn.rotary_emb.base
-        dim = self._model.model.layers[0].self_attn.rotary_emb.dim
-        embed_positions = create_sinusoidal_positions(max_positions, dim, base)
-
-        for layer in self._model.model.layers:
-            layer.self_attn.rotary_emb.register_buffer("embed_positions", embed_positions)
-            layer.self_attn.rotary_emb._orig_forward = layer.self_attn.rotary_emb.forward
-
-            layer.self_attn.rotary_emb.forward = types.MethodType(
-                llama_gemma_rotary_emb_forward, layer.self_attn.rotary_emb
-            )
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        if hasattr(self._model.model, "_orig_update_causal_mask"):
+            self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask
+
+        for layer in self._model.model.layers:
+            layer.self_attn.rotary_emb.forward = layer.self_attn.rotary_emb._orig_forward
+
+
+# copied from https://github.com/huggingface/transformers/commit/57d7594a79a9f5d835abf2d4d384db0e4818e548 to unblock export with transformers 4.42
+def _mistral_update_causal_mask(
+    self,
+    attention_mask: torch.Tensor,
+    input_tensor: torch.Tensor,
+    cache_position: torch.Tensor,
+    past_key_values: "Cache",
+    use_cache: bool,
+    output_attentions: bool,
+):
+    from transformers.cache_utils import SlidingWindowCache, StaticCache
+    from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+
+    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
+    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
+    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
+    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
+
+    if self._attn_implementation == "flash_attention_2":
+        if attention_mask is not None and use_cache:
+            is_padding_right = attention_mask[:, -1].sum().item() != input_tensor.size()[0]
+            if is_padding_right:
+                raise ValueError(
+                    "You are attempting to perform batched generation with padding_side='right'"
+                    " this may lead to unexpected behaviour for Flash Attention version of Mistral. Make sure to "
+                    " call `tokenizer.padding_side = 'left'` before tokenizing the input. "
+                )
+        if attention_mask is not None and 0.0 in attention_mask:
+            return attention_mask
+        return None
+
+    # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+    # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+    # to infer the attention mask.
+
+    # cache_position must be valid here no matter which cache we use
+    past_seen_tokens = cache_position[0] if past_key_values is not None else 0
+    using_static_cache = isinstance(past_key_values, StaticCache)
+    using_sliding_window_cache = isinstance(past_key_values, SlidingWindowCache)
+
+    if (
+        self.config._attn_implementation == "sdpa"
+        and not (using_static_cache or using_sliding_window_cache)
+        and not output_attentions
+    ):
+        if AttentionMaskConverter._ignore_causal_mask_sdpa(
+            attention_mask,
+            inputs_embeds=input_tensor,
+            past_key_values_length=past_seen_tokens,
+            sliding_window=self.config.sliding_window,
+            is_training=self.training,
+        ):
+            return None
+
+    dtype, device = input_tensor.dtype, input_tensor.device
+    min_dtype = torch.finfo(dtype).min
+    sequence_length = input_tensor.shape[1]
+    # SlidingWindowCache
+    if using_sliding_window_cache:
+        target_length = max(sequence_length, self.config.sliding_window)
+    # StaticCache
+    elif using_static_cache:
+        target_length = past_key_values.get_max_length()
+    # DynamicCache or no cache
+    else:
+        target_length = (
+            attention_mask.shape[-1]
+            if isinstance(attention_mask, torch.Tensor)
+            else past_seen_tokens + sequence_length + 1
+        )
+
+    if attention_mask is not None and attention_mask.dim() == 4:
+        # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
+        if attention_mask.max() != 0:
+            raise ValueError("Custom 4D attention mask should be passed in inverted form with max==0`")
+        causal_mask = attention_mask
+    else:
+        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
+        exclude_mask = torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
+        if self.config.sliding_window is not None:
+            if not using_sliding_window_cache or sequence_length > self.config.sliding_window:
+                exclude_mask = exclude_mask.bitwise_or(
+                    torch.arange(target_length, device=device)
+                    <= (cache_position.reshape(-1, 1) - self.config.sliding_window)
+                )
+        causal_mask *= exclude_mask
+        causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
+        if attention_mask is not None:
+            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
+            if attention_mask.dim() == 2:
+                mask_length = attention_mask.shape[-1]
+                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
+                padding_mask = padding_mask == 0
+                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
+                    padding_mask, min_dtype
+                )
+
+    if (
+        self.config._attn_implementation == "sdpa"
+        and attention_mask is not None
+        and attention_mask.device.type == "cuda"
+        and not output_attentions
+    ):
+        # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+        # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+        # Details: https://github.com/pytorch/pytorch/issues/110213
+        causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+    return causal_mask
+
+
+class MistralModelPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        if is_transformers_version(">=", "4.42.0"):
+            # apply fix https://github.com/huggingface/transformers/commit/57d7594a79a9f5d835abf2d4d384db0e4818e548
+            self._model.model._orig_update_causal_mask = self._model.model._update_causal_mask
+            self._model.model._update_causal_mask = types.MethodType(_mistral_update_causal_mask, self._model.model)
+
+        # mistral has some accuracy issues with bf16 with transformers >= 4.42
+        # prefill rotary emb sin/cos for avoid this issue
+        register_sin_cos_buffer(self._model)
+
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
+
         if hasattr(self._model.model, "_orig_update_causal_mask"):
             self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask
 
-        for layer in self._model.model.layers:
-            layer.self_attn.rotary_emb.forward = layer.self_attn.rotary_emb._orig_forward
+        for layer in self._model.model.layers:
+            if hasattr(layer.self_attn.rotary_emb, "_orig_forward"):
+                layer.self_attn.rotary_emb.forward = layer.self_attn.rotary_emb._orig_forward
 
 
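The bitwise_or branch above is what the commit message calls the "bitwise fix for mistral": the sliding-window cutoff is folded into the same exclusion mask as the causal condition. A toy, self-contained rendering of that construction (illustrative sizes, not optimum code):

# Toy sketch of the causal sliding-window mask built with arange comparisons and bitwise_or.
import torch

sequence_length, target_length, sliding_window = 6, 6, 3
dtype = torch.float32
min_dtype = torch.finfo(dtype).min
cache_position = torch.arange(sequence_length)  # positions of the current tokens

causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype)
exclude_mask = torch.arange(target_length) > cache_position.reshape(-1, 1)           # future tokens
exclude_mask = exclude_mask.bitwise_or(
    torch.arange(target_length) <= (cache_position.reshape(-1, 1) - sliding_window)  # tokens outside the window
)
causal_mask *= exclude_mask      # min_dtype where excluded, 0 where attention is allowed
print(causal_mask == 0)          # True inside the causal band of width `sliding_window`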
@@ -1283,11 +1425,15 @@ def __enter__(self):
                     rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
                 )
 
+        # phi3 has issue with bf16 inference, precollect sin/cos for rotary_position_embedding for avoid accuracy issues
+        register_sin_cos_buffer(self._model)
+
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
         for layer in self._model.model.layers:
             if hasattr(layer.self_attn, "_orig_forward"):
                 layer.self_attn.forward = layer.self_attn._orig_forward
+            layer.self_attn.rotary_emb.forward = layer.self_attn.rotary_emb._orig_forward
 
 
 def _aquila_self_attn_sdpa_forward(
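The patchers in this file all follow the same monkey-patch discipline: stash the original bound method as _orig_forward, rebind a replacement with types.MethodType on __enter__, and restore the original on __exit__ so the model is left untouched after export. A minimal toy sketch of that pattern (not optimum code):

# Toy save/patch/restore context manager mirroring the patcher pattern above.
import types

class Toy:
    def forward(self):
        return "original"

def patched_forward(self):
    return "patched"

class ToyPatcher:
    def __init__(self, model):
        self._model = model

    def __enter__(self):
        self._model._orig_forward = self._model.forward                      # keep the original bound method
        self._model.forward = types.MethodType(patched_forward, self._model) # rebind the replacement
        return self._model

    def __exit__(self, exc_type, exc_value, traceback):
        if hasattr(self._model, "_orig_forward"):
            self._model.forward = self._model._orig_forward                  # put the original back

m = Toy()
with ToyPatcher(m):
    assert m.forward() == "patched"
assert m.forward() == "original"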
@@ -1807,6 +1953,18 @@ def __enter__(self):
             _dbrx_update_causal_mask, self._model.transformer
         )
 
+        # starting from transformers 4.41 issue also observable for calculation sin/cos for rotary_emb
+        patch_rope_sin_cos = is_transformers_version(">=", "4.41.0")
+
+        inv_freq = getattr(self._model.transformer.blocks[0].norm_attn_norm.attn.rotary_emb, "inv_freq")
+        dim, base = None, None
+        if inv_freq is None:
+            dim = self._model.transformer.blocks[0].norm_attn_norm.attn.rotary_emb.dim
+            base = self._model.transformer.blocks[0].norm_attn_norm.attn.rotary_emb.base
+        max_positions = self._model.config.max_seq_len
+        if patch_rope_sin_cos:
+            embed_positions = create_sinusoidal_positions(max_positions, dim, base, inv_freq)
+
         for block in self._model.transformer.blocks:
             rotary_emb = block.norm_attn_norm.attn.rotary_emb
             # initialize inv_freq for torchscript tracing
@@ -1815,6 +1973,12 @@ def __enter__(self):
                     rotary_emb.base ** (torch.arange(0, rotary_emb.dim, 2, dtype=torch.int64).float() / rotary_emb.dim)
                 )
                 rotary_emb.inv_freq = inv_freq
+
+            if patch_rope_sin_cos:
+                rotary_emb.register_buffer("embed_positions", embed_positions)
+                rotary_emb._orig_forward = rotary_emb.forward
+                rotary_emb.forward = types.MethodType(llama_gemma_rotary_emb_forward, rotary_emb)
+
             # remove continue-operator from iteration loop over experts
             block.ffn.experts._orig_forward = block.ffn.experts.forward
             block.ffn.experts.forward = types.MethodType(_dbrx_experts_forward, block.ffn.experts)
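The dbrx change reuses an existing inv_freq when the module already has one (the "reuse inv_freq" item in the commit message) and only falls back to dim/base otherwise. A quick standalone check, with the helper copied from the diff above and arbitrary sizes, that both call styles produce the same table:

# Standalone check: passing an existing inv_freq matches deriving it from dim/base.
import torch

def create_sinusoidal_positions(num_pos: int, dim: int, base: int = 10000, inv_freq=None) -> torch.Tensor:
    if inv_freq is None:
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))
    sinusoid_inp = torch.einsum("i , j -> i j", torch.arange(num_pos, dtype=torch.int64).float(), inv_freq).float()
    emb = torch.cat((sinusoid_inp, sinusoid_inp), dim=-1)
    return torch.cat((torch.sin(emb), torch.cos(emb)), dim=1)

dim, base, max_pos = 64, 10000, 32
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64) / dim))

from_dim_base = create_sinusoidal_positions(max_pos, dim, base)
from_inv_freq = create_sinusoidal_positions(max_pos, dim, base, inv_freq=inv_freq)
assert torch.equal(from_dim_base, from_inv_freq)  # identical table either way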
@@ -1825,6 +1989,9 @@ def __exit__(self, exc_type, exc_value, traceback):
         for block in self._model.transformer.blocks:
             block.ffn.experts.forward = block.ffn.experts._orig_forward
 
+            if hasattr(block.norm_attn_norm.attn.rotary_emb, "_orig_forward"):
+                block.norm_attn_norm.attn.rotary_emb.forward = block.norm_attn_norm.attn.rotary_emb._orig_forward
+
 
 # Adapted from https://github.com/huggingface/transformers/blob/v4.41.0/src/transformers/models/persimmon/modeling_persimmon.py#L264
 def _persimmon_self_attn_sdpa_forward(
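For readers wondering what the registered "embed_positions" buffer is used for: llama_gemma_rotary_emb_forward, whose body is outside this diff, presumably just gathers rows of the float32 table by position id instead of recomputing sin/cos inside the traced graph. A hypothetical sketch of that lookup (an assumption, not the actual function):

# Hypothetical buffer-lookup sketch; "table" stands in for the registered embed_positions buffer.
import torch

table = torch.rand(128, 2 * 64)               # (max_positions, 2 * rotary_dim), sin half then cos half
position_ids = torch.tensor([[0, 1, 2, 3]])
sincos = table[position_ids]                  # gather the rows for the requested positions
sin, cos = torch.split(sincos, sincos.shape[-1] // 2, dim=-1)
print(sin.shape, cos.shape)                   # torch.Size([1, 4, 64]) for each half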
