
Commit 0447ae2

add patching for update_causal_mask to falcon for >= 4.45 (#989)

* add patching for update_causal_mask to falcon and gpt-like models for >= 4.45
* fix falcon
* enable codegen2 back
* Apply suggestions from code review
* Update optimum/exporters/openvino/model_patcher.py

Co-authored-by: Nikita Savelyev <[email protected]>
1 parent 12783ee commit 0447ae2

File tree (3 files changed, +237 −16 lines)

optimum/exporters/openvino/model_configs.py
optimum/exporters/openvino/model_patcher.py
tests/openvino/test_modeling.py


optimum/exporters/openvino/model_configs.py

Lines changed: 20 additions & 0 deletions

@@ -29,6 +29,7 @@
     CodeGenOnnxConfig,
     FalconOnnxConfig,
     GemmaOnnxConfig,
+    GPTJOnnxConfig,
     GPTNeoXOnnxConfig,
     IBertOnnxConfig,
     LlamaOnnxConfig,
@@ -66,6 +67,7 @@
     FalconModelPatcher,
     FluxTransfromerModelPatcher,
     Gemma2ModelPatcher,
+    GptJModelPatcher,
     GptNeoxJapaneseModelPatcher,
     GptNeoxModelPatcher,
     IBertModelPatcher,
@@ -726,6 +728,24 @@ def patch_model_for_export(
         return GptNeoxJapaneseModelPatcher(self, model, model_kwargs=model_kwargs)
 
 
+@register_in_tasks_manager(
+    "gptj",
+    *[
+        "feature-extraction",
+        "feature-extraction-with-past",
+        "text-generation",
+        "text-generation-with-past",
+        "text-classification",
+    ],
+    library_name="transformers",
+)
+class GPTJOpenVINOConfig(GPTJOnnxConfig):
+    def patch_model_for_export(
+        self, model: Union["PreTrainedModel", "TFPreTrainedModel"], model_kwargs: Optional[Dict[str, Any]] = None
+    ) -> "ModelPatcher":
+        return GptJModelPatcher(self, model, model_kwargs=model_kwargs)
+
+
 @register_in_tasks_manager(
     "cohere",
     *[
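
The new registration wires GPT-J into the OpenVINO export flow: the tasks manager resolves GPTJOpenVINOConfig for the listed tasks, and its patch_model_for_export returns a GptJModelPatcher that is applied as a context manager around tracing. A hedged usage sketch; the checkpoint name and the direct constructor call GPTJOpenVINOConfig(model.config, task=...) are assumptions about the surrounding optimum API, not shown in this diff (the export pipeline normally resolves the config through the tasks manager):

from transformers import AutoModelForCausalLM

from optimum.exporters.openvino.model_configs import GPTJOpenVINOConfig

# illustrative tiny checkpoint; any GPT-J model would do
model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-GPTJForCausalLM")

# the registered config builds the patcher; entering it swaps _update_causal_mask
# on model.transformer for the duration of tracing, and __exit__ restores it
export_config = GPTJOpenVINOConfig(model.config, task="text-generation")
with export_config.patch_model_for_export(model):
    pass  # trace / convert the patched model here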

optimum/exporters/openvino/model_patcher.py

Lines changed: 216 additions & 12 deletions

@@ -109,11 +109,20 @@ def patch_model_with_bettertransformer(model):
     return model
 
 
-def patch_update_causal_mask(model, transformers_version):
+def patch_update_causal_mask(model, transformers_version, inner_model_name="model", patch_fn=None):
     if is_transformers_version(">=", transformers_version):
-        inner_model = getattr(model, "model", getattr(model, "transformer", None))
+        inner_model = getattr(model, inner_model_name, None)
         if inner_model is not None:
-            inner_model._update_causal_mask = types.MethodType(_llama_gemma_update_causal_mask, inner_model)
+            if hasattr(inner_model, "_update_causal_mask"):
+                inner_model._orig_update_causal_mask = inner_model._update_causal_mask
+            patch_fn = patch_fn or _llama_gemma_update_causal_mask
+            inner_model._update_causal_mask = types.MethodType(patch_fn, inner_model)
+
+
+def unpatch_update_causal_mask(model, inner_model_name="model"):
+    inner_model = getattr(model, inner_model_name, None)
+    if inner_model is not None and hasattr(inner_model, "_orig_update_causal_mask"):
+        inner_model._update_causal_mask = inner_model._orig_update_causal_mask
 
 
 # initialization of sin/cos cached in bf16/fp16 leads to accuracy loss
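
The pair of helpers above is the standard types.MethodType monkey-patch-and-restore idiom applied to the decoder's _update_causal_mask. A minimal, self-contained sketch of that idiom; the toy class and return values are illustrative, not taken from the repository:

import types


class ToyDecoder:
    def _update_causal_mask(self, attention_mask):
        return "original mask logic"


def patched_update_causal_mask(self, attention_mask):
    return "export-friendly mask logic"


inner = ToyDecoder()

# patch: remember the original bound method, then bind the replacement to the instance
inner._orig_update_causal_mask = inner._update_causal_mask
inner._update_causal_mask = types.MethodType(patched_update_causal_mask, inner)
assert inner._update_causal_mask(None) == "export-friendly mask logic"

# unpatch: put the original bound method back
inner._update_causal_mask = inner._orig_update_causal_mask
assert inner._update_causal_mask(None) == "original mask logic"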
@@ -579,13 +588,11 @@ def __enter__(self):
 
         # llama/gemma has some accuracy issues with bf16 with transformers >= 4.39
         # fill causal mask in slightly different way for avoid overflow on some platforms
-        patch_update_causal_mask(self._model, "4.39.0")
+        patch_update_causal_mask(self._model, "4.39.0", "model" if hasattr(self._model, "model") else "transformer")
 
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
-        inner_model = getattr(self._model, "model", getattr(self._model, "transformer", None))
-        if hasattr(inner_model, "_orig_update_causal_mask"):
-            inner_model._update_causal_mask = inner_model._orig_update_causal_mask
+        unpatch_update_causal_mask(self._model, "model" if hasattr(self._model, "model") else "transformer")
 
 
 # copied from https://github.com/huggingface/transformers/commit/57d7594a79a9f5d835abf2d4d384db0e4818e548 to unblock export with transformers 4.42
@@ -1865,6 +1872,67 @@ def __exit__(self, exc_type, exc_value, traceback):
             layer.self_attn.forward = layer.self_attn._orig_forward
 
 
+# copied from https://github.com/huggingface/optimum/blob/2112e99122d7f23a1da1a9d263fef64301050ea7/optimum/bettertransformer/models/attention.py#L168
+# for preserving backward compatibility between outdated codegen remote code and new transformers
+def _codegen_wrapped_scaled_dot_product_legacy(
+    self,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    attention_mask: Optional[torch.Tensor] = None,
+    head_mask: Optional[torch.Tensor] = None,
+):
+    from optimum.bettertransformer.models.attention import raise_on_head_mask
+
+    raise_on_head_mask(head_mask)
+    batch_size = query.shape[0]
+    mask_value = torch.finfo(value.dtype).min
+    mask_value = torch.full([], mask_value, dtype=value.dtype)
+
+    if batch_size == 1 and attention_mask is not None and attention_mask[0, 0, -1, -1] < -1:
+        raise ValueError("BetterTransformer does not support padding='max_length' with a batch size of 1.")
+
+    # in codegen the query and key are always in fp32 regardless of the dtype of the model
+    # https://github.com/huggingface/transformers/blob/5b28b7833297adf65c5160a685425ddb1eee5ce2/src/transformers/models/codegen/modeling_codegen.py#L226
+    query = query.to(value.dtype)
+    key = key.to(value.dtype)
+
+    dropout_p = self.dropout_prob_attn if self.training else 0.0
+    if batch_size == 1 or self.training:
+        if query.shape[2] > 1:
+            # first step of the decoding
+            sdpa_result = torch.nn.functional.scaled_dot_product_attention(
+                query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=True
+            )
+        else:
+            # in this case, which is the later decoding steps, the `causal_mask` in
+            # https://github.com/huggingface/transformers/blob/ae54e3c3b18bac0832ad62ea9b896dfd52a09850/src/transformers/models/gpt2/modeling_gpt2.py#L195
+            # is [True, ..., True] so actually not causal
+            sdpa_result = torch.nn.functional.scaled_dot_product_attention(
+                query, key, value, attn_mask=None, dropout_p=dropout_p, is_causal=False
+            )
+    else:
+        query_length, key_length = query.size(-2), key.size(-2)
+
+        # causal_mask is always [True, ..., True] otherwise, so executing this is unnecessary
+        if query_length > 1:
+            causal_mask = self.causal_mask[:, :, key_length - query_length : key_length, :key_length].to(torch.bool)
+
+            causal_mask = torch.where(causal_mask, 0, mask_value)
+
+            # torch.Tensor.expand does no memory copy
+            causal_mask = causal_mask.expand(batch_size, -1, -1, -1)
+
+            # we use torch.min to avoid having tensor(-inf)
+            attention_mask = torch.min(causal_mask, attention_mask)
+
+        sdpa_result = torch.nn.functional.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=dropout_p, is_causal=False
+        )
+
+    return sdpa_result, None
+
+
 class CodeGenModelPatcher(DecoderModelPatcher):
     def __enter__(self):
         super().__enter__()
@@ -1873,14 +1941,23 @@ def __enter__(self):
         # For avoiding breaking model on tracing stage, we reduce area of bettertransformer patch only for _attn.
         from optimum.bettertransformer.models.attention import codegen_wrapped_scaled_dot_product
 
+        attn_fn = codegen_wrapped_scaled_dot_product
+        if is_torch_version(">=", "2.1.0") and is_transformers_version(">=", "4.45"):
+            # in transformers 4.45 causal_mask const buffer was removed from the model
+            # if it still exists, it means legacy remote code was loaded
+            if hasattr(self._model.transformer.h[0].attn, "causal_mask"):
+                attn_fn = _codegen_wrapped_scaled_dot_product_legacy
+
         for layer in self._model.transformer.h:
             if is_torch_version(">=", "2.1.0") and not self._model.config.output_attentions:
                 orig_self_attn_fwd = layer.attn._attn
-                layer.attn._attn = types.MethodType(codegen_wrapped_scaled_dot_product, layer.attn)
+                layer.attn._attn = types.MethodType(attn_fn, layer.attn)
                 layer.attn._orig_attn = orig_self_attn_fwd
+        patch_update_causal_mask(self._model, "4.45.0", "transformer")
 
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
+        unpatch_update_causal_mask(self._model, "transformer")
         for layer in self._model.transformer.h:
             if hasattr(layer.attn, "_orig_attn"):
                 layer.attn._attn = layer.attn._orig_attn
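
The legacy wrapper above switches between three scaled_dot_product_attention call patterns: a built-in causal mask for prefill, no mask for a single-token decode step, and an explicit additive mask for batched or padded inputs. The same three calls reproduced in isolation, with made-up toy shapes:

import torch
import torch.nn.functional as F

# toy shapes: batch=1, heads=2, sequence length=4, head_dim=8
q = torch.randn(1, 2, 4, 8)
k = torch.randn(1, 2, 4, 8)
v = torch.randn(1, 2, 4, 8)

# prefill: let SDPA build the causal mask itself
prefill = F.scaled_dot_product_attention(q, k, v, attn_mask=None, is_causal=True)

# single-token decode step: the causal row is all ones, so no mask is needed at all
step = F.scaled_dot_product_attention(q[:, :, -1:, :], k, v, attn_mask=None, is_causal=False)

# batched/padded case: pass an additive float mask (0 = keep, large negative = drop)
additive_mask = torch.zeros(1, 1, 4, 4)
additive_mask[..., 0] = torch.finfo(q.dtype).min  # e.g. mask out the first key position
masked = F.scaled_dot_product_attention(q, k, v, attn_mask=additive_mask, is_causal=False)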
@@ -2275,8 +2352,7 @@ def __enter__(self):
 
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
-        if hasattr(self._model.model, "_orig_update_causal_mask"):
-            self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask
+        unpatch_update_causal_mask(self._model)
         for layer in self._model.model.layers:
             if hasattr(layer.self_attn, "_orig_forward"):
                 layer.self_attn.forward = layer.self_attn._orig_forward
@@ -2413,8 +2489,7 @@ def __enter__(self):
 
     def __exit__(self, exc_type, exc_value, traceback):
         super().__exit__(exc_type, exc_value, traceback)
-        if hasattr(self._model.model, "_orig_update_causal_mask"):
-            self._model.model._update_causal_mask = self._model.model._orig_update_causal_mask
+        unpatch_update_causal_mask(self._model)
 
 
 class RotaryEmbPatcher(DecoderModelPatcher):
@@ -2425,12 +2500,119 @@ def __enter__(self):
                 _reinitialize_cos_sin_cached_fp32(layer.self_attn.rotary_emb)
 
 
+def _falcon_update_causal_mask(
+    self,
+    attention_mask: torch.Tensor,
+    input_tensor: torch.Tensor,
+    cache_position: torch.Tensor,
+    past_key_values: "Cache",
+    output_attentions: bool,
+    head_mask: torch.Tensor,
+    alibi: torch.Tensor,
+):
+    # copied from https://github.com/huggingface/transformers/blob/a30c865f991dfec9452cc64bd9a97bfbb96be036/src/transformers/models/falcon/modeling_falcon.py#L1130
+    from transformers.cache_utils import StaticCache
+    from transformers.modeling_attn_mask_utils import AttentionMaskConverter
+
+    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
+    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
+    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
+    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
+
+    if hasattr(self, "_prepare_4d_causal_attention_mask_with_cache_position"):
+        _prepare_4d_causal_attention_mask_with_cache_position = (
+            self._prepare_4d_causal_attention_mask_with_cache_position
+        )
+    else:
+        from transformers.models.falcon.modeling_falcon import _prepare_4d_causal_attention_mask_with_cache_position
+
+    if self.config._attn_implementation == "flash_attention_2":
+        if attention_mask is not None and 0.0 in attention_mask:
+            return attention_mask
+        return None
+
+    # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
+    # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
+    # to infer the attention mask.
+    past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
+    using_static_cache = isinstance(past_key_values, StaticCache)
+
+    # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
+    if (
+        self.config._attn_implementation == "sdpa"
+        and not using_static_cache
+        and not output_attentions
+        and head_mask is None
+        and alibi is None
+    ):
+        if AttentionMaskConverter._ignore_causal_mask_sdpa(
+            attention_mask,
+            inputs_embeds=input_tensor,
+            past_key_values_length=past_seen_tokens,
+            is_training=self.training,
+        ):
+            return None
+
+    dtype, device = input_tensor.dtype, input_tensor.device
+    # difference from original, replace torch.finfo(dtype).min to float16 for prevent overflow for fp16/bf16 execution
+    min_dtype = torch.finfo(torch.float16).min
+    batch_size, sequence_length, _ = input_tensor.shape
+    if using_static_cache:
+        target_length = past_key_values.get_max_length()
+    else:
+        target_length = (
+            attention_mask.shape[-1]
+            if isinstance(attention_mask, torch.Tensor)
+            else past_seen_tokens + sequence_length
+        )
+
+    # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
+    causal_mask = _prepare_4d_causal_attention_mask_with_cache_position(
+        attention_mask,
+        sequence_length=sequence_length,
+        target_length=target_length,
+        dtype=dtype,
+        device=device,
+        min_dtype=min_dtype,
+        cache_position=cache_position,
+        batch_size=input_tensor.shape[0],
+    )
+
+    # We take care to integrate alibi bias in the causal_mask here
+    if head_mask is None and alibi is not None:
+        alibi = alibi.reshape(batch_size, -1, *alibi.shape[1:])
+        causal_mask = torch.masked_fill(
+            alibi / math.sqrt(self.config.hidden_size // self.num_heads),
+            causal_mask < -1,
+            min_dtype,
+        )
+
+    if (
+        self.config._attn_implementation == "sdpa"
+        and attention_mask is not None
+        and attention_mask.device.type == "cuda"
+        and not output_attentions
+    ):
+        # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
+        # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
+        # Details: https://github.com/pytorch/pytorch/issues/110213
+        causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
+
+    return causal_mask
+
+
 class FalconModelPatcher(DecoderModelPatcher):
     def __enter__(self):
         super().__enter__()
         if is_transformers_version("<", "4.44.99"):
             for layer in self._model.transformer.h:
                 _reinitialize_cos_sin_cached_fp32(layer.self_attention.rotary_emb)
+        else:
+            patch_update_causal_mask(self._model, "4.45.0", "transformer", _falcon_update_causal_mask)
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        unpatch_update_causal_mask(self._model, "transformer")
 
 
 class GptNeoxModelPatcher(DecoderModelPatcher):
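
The one intentional deviation from the upstream Falcon mask code, flagged in the comment above, is the fill value: torch.finfo(torch.float16).min instead of torch.finfo(dtype).min, so the mask stays finite if the graph later runs in fp16/bf16. A small illustration of the difference:

import torch

print(torch.finfo(torch.float32).min)    # ~-3.4028e+38
print(torch.finfo(torch.bfloat16).min)   # ~-3.3895e+38
print(torch.finfo(torch.float16).min)    # -65504.0

# float32's minimum does not fit into fp16, so a mask filled with it overflows to -inf
# once the model (or a later cast) runs in half precision; -65504.0 stays finite.
print(torch.tensor(torch.finfo(torch.float32).min).to(torch.float16))  # tensor(-inf, dtype=torch.float16)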
@@ -2439,6 +2621,22 @@ def __enter__(self):
         if is_transformers_version("<", "4.44.99"):
             for layer in self._model.gpt_neox.layers:
                 _reinitialize_cos_sin_cached_fp32(layer.attention.rotary_emb)
+        else:
+            patch_update_causal_mask(self._model, "4.45.0", "gpt_neox")
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        unpatch_update_causal_mask(self._model, "gpt_neox")
+
+
+class GptJModelPatcher(DecoderModelPatcher):
+    def __enter__(self):
+        super().__enter__()
+        patch_update_causal_mask(self._model, "4.45.0", "transformer")
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        unpatch_update_causal_mask(self._model, "transformer")
 
 
 class GptNeoxJapaneseModelPatcher(DecoderModelPatcher):
@@ -2447,6 +2645,12 @@ def __enter__(self):
         if is_transformers_version("<", "4.44.99"):
             for layer in self._model.gpt_neox_japanese.layers:
                 _reinitialize_cos_sin_cached_fp32(layer.attention.rotary_emb)
+        else:
+            patch_update_causal_mask(self._model, "4.45.0", "gpt_neox_japanese")
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        super().__exit__(exc_type, exc_value, traceback)
+        unpatch_update_causal_mask(self._model, "gpt_neox_japanese")
 
 
 class Gemma2ModelPatcher(LlamaModelPatcher):
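
Each of the decoder patchers above branches on the same transformers version check; in isolation the gate looks like this (a sketch assuming is_transformers_version keeps its current home in optimum.intel.utils.import_utils):

from optimum.intel.utils.import_utils import is_transformers_version

if is_transformers_version("<", "4.44.99"):
    # pre-4.45 path: re-initialize the rotary embedding sin/cos caches in fp32
    print("patch rotary embeddings")
else:
    # 4.45+ path: swap _update_causal_mask on the inner decoder module
    print("patch _update_causal_mask")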

tests/openvino/test_modeling.py

Lines changed: 1 addition & 4 deletions

@@ -773,6 +773,7 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
         "bloom",
         "chatglm",
         "codegen",
+        "codegen2",
         "gpt2",
         "gpt_neo",
         "gpt_neox",
@@ -821,10 +822,6 @@ class OVModelForCausalLMIntegrationTest(unittest.TestCase):
         "mistral-nemo",
     )
 
-    # custom modeling defined in https://huggingface.co/katuni4ka/tiny-random-codegen2 differs from transformers after v4.45 resulting in unadapted patching
-    if is_transformers_version("<", "4.45.0"):
-        SUPPORTED_ARCHITECTURES += ("codegen2",)
-
     GENERATION_LENGTH = 100
     REMOTE_CODE_MODELS = (
         "chatglm",
