
Commit a80f30b

Fix CE patch and add layernorm support for InternVL (#921)
## Summary

Fix the CE patch and add LayerNorm support for InternVL. Related issue: #920

## Testing Done

- Hardware Type: A100
- [ ] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence

```
(9 durations < 0.005s hidden. Use -vv to show these durations.)
================================================================================================================= short test summary info =================================================================================================================
FAILED test/convergence/fp32/test_mini_models_multimodal.py::test_mini_model_multimodal[mini_llava-32-0.0001-dtype1-1e-08-1e-05-0.005-1e-05-0.005-1e-05] - TypeError: 'NoneType' object is not subscriptable
================================================================================================== 1 failed, 7 passed, 11 warnings in 147.93s (0:02:27) ===================================================================================================
sys:1: DeprecationWarning: builtin type swigvarlink has no __module__ attribute
make: *** [Makefile:32: test-convergence] Error 1
```

The llava convergence test crashes; I can't tell why (it may be my transformers version).

---------

Co-authored-by: Steven Shimizu <[email protected]>
1 parent 606ca4e commit a80f30b
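
The cross-entropy fix hinges on where recent transformers releases compute the loss: they route it through `transformers.loss.loss_utils`, which calls `nn.functional.cross_entropy`, so patching the `modeling_internvl.nn.CrossEntropyLoss` attribute (the old approach) does not take effect there. A minimal sketch of the patching idea follows; the `liger_cross_entropy` import path is an assumption based on the existing imports in `monkey_patch.py`, not shown in this diff.

```python
# Sketch only: recent transformers compute loss via transformers.loss.loss_utils,
# which calls nn.functional.cross_entropy, so the patch swaps that function instead
# of the nn.CrossEntropyLoss class.
from transformers.loss import loss_utils

# Assumed import path; monkey_patch.py imports liger_cross_entropy elsewhere.
from liger_kernel.transformers.functional import liger_cross_entropy

loss_utils.nn.functional.cross_entropy = liger_cross_entropy
```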

File tree: 3 files changed (+77, -34 lines)


src/liger_kernel/transformers/monkey_patch.py

Lines changed: 52 additions & 20 deletions
```diff
@@ -2038,6 +2038,7 @@ def apply_liger_kernel_to_internvl(
     cross_entropy: bool = False,
     fused_linear_cross_entropy: bool = True,
     rms_norm: bool = True,
+    layer_norm: bool = True,
     model: Optional[PreTrainedModel] = None,
     **kwargs,
 ) -> None:
@@ -2048,37 +2049,60 @@ def apply_liger_kernel_to_internvl(
     NOTE: InternVL is not available in transformers<4.52.1
 
     Args:
-        rope (bool): Whether to apply Liger's rotary position embedding. Default is True.
         cross_entropy (bool): Whether to apply Liger's cross entropy loss. Default is False.
         fused_linear_cross_entropy (bool):
             Whether to apply Liger's fused linear cross entropy loss. Default is True.
             `cross_entropy` and `fused_linear_cross_entropy` cannot both be True.
             If `fused_linear_cross_entropy` is True, the logits will not be materialized but more memory efficient.
         rms_norm (bool): Whether to apply Liger's RMSNorm. Default is True.
-        swiglu (bool): Whether to apply Liger's SwiGLU MLP. Default is True.
+        layer_norm (bool): Whether to apply Liger's LayerNorm. Default is True.
         model (PreTrainedModel): The model instance to apply Liger kernels to, if the model has already been
             loaded. Default is None.
     """
     assert not (cross_entropy and fused_linear_cross_entropy), (
         "cross_entropy and fused_linear_cross_entropy cannot both be True."
     )
+    import torch.nn as torch_nn
 
     from transformers.models.internvl import modeling_internvl
+    from transformers.models.internvl.modeling_internvl import InternVLForConditionalGeneration
+    from transformers.models.internvl.modeling_internvl import InternVLModel
+    from transformers.models.internvl.modeling_internvl import InternVLVisionLayer
+    from transformers.models.internvl.modeling_internvl import InternVLVisionModel
+    from transformers.models.internvl.modeling_internvl import InternVLVisionRMSNorm
 
+    from liger_kernel.transformers.layer_norm import LigerLayerNorm
     from liger_kernel.transformers.model.internvl import lce_forward as internvl_lce_forward
+    from liger_kernel.transformers.rms_norm import LigerRMSNorm
+
+    if layer_norm and model is None:
+        modeling_internvl.nn.LayerNorm = LigerLayerNorm
 
     if cross_entropy:
-        logger.warning(TRANSFORMER_DEPRECATION_WARNING)
-        modeling_internvl.nn.CrossEntropyLoss = LigerCrossEntropyLoss
+        logger.info("Apply liger cross entropy")
+
+        from transformers.loss.loss_utils import nn
+
+        nn.functional.cross_entropy = liger_cross_entropy
     if fused_linear_cross_entropy:
         modeling_internvl.InternVLForConditionalGeneration.forward = internvl_lce_forward
     if rms_norm:
         modeling_internvl.InternVLVisionRMSNorm = LigerRMSNorm
 
     if model is not None:
-        text_model_name, vision_model_name = model.config.text_config.model_type, model.config.vision_config.model_type
+        # The model instance already exists, so we need to additionally patch the
+        # instance variables that reference already-instantiated modules
+        if isinstance(model, (InternVLForConditionalGeneration, InternVLModel)):
+            # NOTE: language_model and visual properties can be accessed throught conditional class.
+            text_model = model.language_model
+            vision_model: InternVLVisionModel = model.vision_tower
+        else:
+            raise TypeError(
+                f"Unsupported internvl model type. `model` must be `InternVLForConditionalGeneration`, `InternVLModel`. Got: {type(model)}"
+            )
+
+        text_model_name = model.config.text_config.model_type
         text_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(text_model_name, None)
-        vision_liger_fn = MODEL_TYPE_TO_APPLY_LIGER_FN.get(vision_model_name, None)
 
         kwargs = {"cross_entropy": False, "fused_linear_cross_entropy": False, **kwargs} | {"rms_norm": rms_norm}
         if text_liger_fn:
@@ -2091,25 +2115,33 @@ def apply_liger_kernel_to_internvl(
                     f"These parameters are not supported by {text_model_name}. Enter the remaining {list(text_kwargs.keys())} except for {list(remain_params)}\n"
                     f"Parameters accepted by {text_model_name}: {list(accept_params.keys())}"
                 )
-            text_kwargs["model"] = model.language_model
+            text_kwargs["model"] = text_model
             text_liger_fn(**text_kwargs)
         elif text_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
             logger.warning(f"{text_model_name} is not supported by Liger kernel.")
 
-        if vision_liger_fn:
-            accept_params = inspect.signature(vision_liger_fn).parameters
-            remain_params = set(kwargs) - (set(accept_params) & set(kwargs))
-            vision_kwargs = {k: v for k, v in kwargs.items() if k not in remain_params}
+        # Patch vision model RMSNorm layers
+        if rms_norm:
+            for encoder_layer in vision_model.encoder.layer:
+                encoder_layer: InternVLVisionLayer
+                if isinstance(encoder_layer.attention.q_norm, InternVLVisionRMSNorm):
+                    _patch_rms_norm_module(encoder_layer.attention.q_norm)
+                if isinstance(encoder_layer.attention.k_norm, InternVLVisionRMSNorm):
+                    _patch_rms_norm_module(encoder_layer.attention.k_norm)
 
-            if remain_params:
-                logger.warning(
-                    f"These parameters are not supported by {vision_model_name}. Enter the remaining {list(vision_kwargs.keys())} except for {list(remain_params)}\n"
-                    f"Parameters accepted by {vision_model_name}: {list(accept_params.keys())}"
-                )
-            vision_kwargs["model"] = model.vision_tower
-            vision_liger_fn(**vision_kwargs)
-        elif vision_model_name not in MODEL_TYPE_TO_APPLY_LIGER_FN:
-            logger.warning(f"{vision_model_name} is not supported by Liger kernel.")
+        # Patch vision model LayerNorm layers
+        if layer_norm:
+            # Patch layernorm
+            if isinstance(vision_model.layernorm, torch_nn.LayerNorm):
+                _patch_layer_norm_module(vision_model.layernorm)
+
+            # Patch encoder layers
+            for encoder_layer in vision_model.encoder.layer:
+                encoder_layer: InternVLVisionLayer
+                if isinstance(encoder_layer.layernorm_before, torch_nn.LayerNorm):
+                    _patch_layer_norm_module(encoder_layer.layernorm_before)
+                if isinstance(encoder_layer.layernorm_after, torch_nn.LayerNorm):
+                    _patch_layer_norm_module(encoder_layer.layernorm_after)
 
 
 def apply_liger_kernel_to_smolvlm(
```
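
For reference, a minimal usage sketch of the patched entry point, assuming transformers>=4.52.1 and an InternVL checkpoint in the transformers-native format; the checkpoint name below is illustrative, not taken from this commit.

```python
import torch
from transformers import AutoModelForImageTextToText

from liger_kernel.transformers.monkey_patch import apply_liger_kernel_to_internvl

# Class-level patching: with model=None and layer_norm=True, the nn.LayerNorm seen by
# modeling_internvl is swapped for LigerLayerNorm before any model is constructed.
apply_liger_kernel_to_internvl(layer_norm=True, rms_norm=True, fused_linear_cross_entropy=True)

# Instance-level patching: for an already-loaded model, the q_norm/k_norm RMSNorm modules
# and the LayerNorm modules inside the vision tower are patched in place.
model = AutoModelForImageTextToText.from_pretrained(
    "OpenGVLab/InternVL3-1B-hf",  # illustrative checkpoint name
    torch_dtype=torch.bfloat16,
)
apply_liger_kernel_to_internvl(layer_norm=True, rms_norm=True, model=model)
```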

test/transformers/test_monkey_patch.py

Lines changed: 23 additions & 14 deletions
```diff
@@ -1637,21 +1637,30 @@ def test_apply_liger_kernel_to_instance_for_internvl():
 
     # Instantiate a dummy model
     config = transformers.models.internvl.configuration_internvl.InternVLConfig(
-        torch_dtype=torch.bfloat16,
-        rms_norm_eps=1e-5,
-        hidden_size=32,
-        intermediate_size=48,
-        hidden_act="silu",
-        num_hidden_layers=2,
-        num_attention_heads=2,
-        max_position_embeddings=128,
-        vocab_size=1000,
+        dtype=torch.bfloat16,
+        text_config={
+            "rms_norm_eps": 1e-5,
+            "hidden_size": 256,  # 1024
+            "intermediate_size": 1024,  # 4096
+            "hidden_act": "silu",
+            "num_hidden_layers": 4,  # 24
+            "num_attention_heads": 4,  # 16
+            "num_key_value_heads": 2,  # 16
+            "max_position_embeddings": 4096,  # 8192
+            "vocab_size": 32000,  # 151936
+            "bos_token_id": 1,
+            "eos_token_id": 2,
+            "pad_token_id": 2,
+            "tie_word_embeddings": False,
+        },
         vision_config={
-            "depth": 4,
-            "embed_dim": 128,
-            "num_heads": 8,
-            "hidden_size": 1024,
+            "hidden_size": 256,  # 1024
+            "intermediate_size": 1024,  # 4096
+            "num_hidden_layers": 4,  # 24
+            "num_attention_heads": 4,  # 16
         },
+        image_token_id=10,
+        attn_implementation="sdpa",  # default value, pytorch native attention
     )
     dummy_model_instance = InternVLForConditionalGeneration._from_config(config)
 
@@ -1692,7 +1701,7 @@ def test_apply_liger_kernel_to_instance_for_smolvlm2():
 
     # Instantiate a dummy model
    config = transformers.models.smolvlm.configuration_smolvlm.SmolVLMConfig(
-        torch_dtype=torch.bfloat16,
+        dtype=torch.bfloat16,
         text_config={
             "rms_norm_eps": 1e-5,
             "hidden_size": 576,
```

test/utils.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -585,10 +585,12 @@ def revert_liger_kernel_to_internvl(model_config: MiniModelConfig):
     """
     Revert all Liger kernel patches applied to InternVL.
     """
+    import torch.nn as nn
 
     from transformers.models.internvl import modeling_internvl
     from transformers.models.qwen2 import modeling_qwen2
 
+    importlib.reload(nn)
     importlib.reload(modeling_internvl)
     importlib.reload(modeling_qwen2)
 
```
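
The extra `importlib.reload(nn)` is needed because the class-level `layer_norm` patch assigns `LigerLayerNorm` onto `torch.nn` itself (via `modeling_internvl.nn`), so reloading only the modeling modules would leave the replacement in place. A small sketch of the effect, using the names from the diff above:

```python
import importlib

import torch.nn as nn

from liger_kernel.transformers.layer_norm import LigerLayerNorm

# What the layer_norm=True, model=None path effectively does: modeling_internvl.nn is torch.nn,
# so the assignment lands on torch.nn.LayerNorm globally.
nn.LayerNorm = LigerLayerNorm
assert nn.LayerNorm is LigerLayerNorm

# Reloading torch.nn re-executes its __init__ and restores the stock LayerNorm class.
importlib.reload(nn)
assert nn.LayerNorm is not LigerLayerNorm
```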
