diff --git a/src/transformers/models/arcee/modeling_arcee.py b/src/transformers/models/arcee/modeling_arcee.py
index 7d9afa092def..3983d4ca0322 100644
--- a/src/transformers/models/arcee/modeling_arcee.py
+++ b/src/transformers/models/arcee/modeling_arcee.py
@@ -312,7 +312,7 @@ class ArceePreTrainedModel(PreTrainedModel):
     config: ArceeConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["ArceeDecoderLayer"]
+    _no_split_modules = ["ArceeDecoderLayer", "ArceeRMSNorm"]
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn = True
     _supports_sdpa = True
diff --git a/src/transformers/models/aria/modeling_aria.py b/src/transformers/models/aria/modeling_aria.py
index 7303ca2e9c50..d7f0854b2715 100644
--- a/src/transformers/models/aria/modeling_aria.py
+++ b/src/transformers/models/aria/modeling_aria.py
@@ -649,7 +649,7 @@ class AriaPreTrainedModel(PreTrainedModel):
     config: AriaConfig
     base_model_prefix = ""
     supports_gradient_checkpointing = True
-    _no_split_modules = ["AriaDecoderLayer"]
+    _no_split_modules = ["AriaDecoderLayer", "AriaRMSNorm"]
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn = True
     _supports_sdpa = True
diff --git a/src/transformers/models/bitnet/modeling_bitnet.py b/src/transformers/models/bitnet/modeling_bitnet.py
index 0af7c794b155..22df51be84ab 100644
--- a/src/transformers/models/bitnet/modeling_bitnet.py
+++ b/src/transformers/models/bitnet/modeling_bitnet.py
@@ -311,7 +311,7 @@ class BitNetPreTrainedModel(PreTrainedModel):
     config: BitNetConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["BitNetDecoderLayer"]
+    _no_split_modules = ["BitNetDecoderLayer", "BitNetRMSNorm"]
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn = True
     _supports_sdpa = True
diff --git a/src/transformers/models/cohere/modeling_cohere.py b/src/transformers/models/cohere/modeling_cohere.py
index 1dfa0ce0be33..bad3bcd56ce0 100644
--- a/src/transformers/models/cohere/modeling_cohere.py
+++ b/src/transformers/models/cohere/modeling_cohere.py
@@ -344,7 +344,7 @@ class CoherePreTrainedModel(PreTrainedModel):
     config: CohereConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["CohereDecoderLayer"]
+    _no_split_modules = ["CohereDecoderLayer", "CohereRMSNorm"]
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn = True
     _supports_sdpa = True
diff --git a/src/transformers/models/cohere2/modeling_cohere2.py b/src/transformers/models/cohere2/modeling_cohere2.py
index bab804aab67e..d19e124658a8 100644
--- a/src/transformers/models/cohere2/modeling_cohere2.py
+++ b/src/transformers/models/cohere2/modeling_cohere2.py
@@ -319,7 +319,7 @@ class Cohere2PreTrainedModel(PreTrainedModel):
     config: Cohere2Config
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["Cohere2DecoderLayer"]
+    _no_split_modules = ["Cohere2DecoderLayer", "Cohere2RMSNorm"]
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn = True
     _supports_sdpa = True
diff --git a/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py b/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py
index fa291e768957..53c5f3c25a8b 100644
--- a/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py
+++ b/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py
@@ -455,7 +455,7 @@ class DeepseekV2PreTrainedModel(PreTrainedModel):
     config: DeepseekV2Config
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["DeepseekV2DecoderLayer"]
+    _no_split_modules = ["DeepseekV2DecoderLayer", "DeepseekV2RMSNorm"]
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn = True
     _supports_sdpa = True
diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py
index c4552fb218ee..6d2ab2f8c7fb 100644
--- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py
+++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py
@@ -501,7 +501,7 @@ class DeepseekV3PreTrainedModel(PreTrainedModel):
     config: DeepseekV3Config
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["DeepseekV3DecoderLayer"]
+    _no_split_modules = ["DeepseekV3DecoderLayer", "DeepseekV3RMSNorm"]
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn = True
     _supports_sdpa = True
diff --git a/src/transformers/models/diffllama/modeling_diffllama.py b/src/transformers/models/diffllama/modeling_diffllama.py
index 094cc375057f..ea0004641186 100644
--- a/src/transformers/models/diffllama/modeling_diffllama.py
+++ b/src/transformers/models/diffllama/modeling_diffllama.py
@@ -532,7 +532,7 @@ class DiffLlamaPreTrainedModel(PreTrainedModel):
     config: DiffLlamaConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["DiffLlamaDecoderLayer"]
+    _no_split_modules = ["DiffLlamaDecoderLayer", "DiffLlamaRMSNorm"]
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn = True
     _supports_sdpa = True
diff --git a/src/transformers/models/doge/modeling_doge.py b/src/transformers/models/doge/modeling_doge.py
index 5822cad62017..5a513dffa63c 100644
--- a/src/transformers/models/doge/modeling_doge.py
+++ b/src/transformers/models/doge/modeling_doge.py
@@ -486,7 +486,7 @@ class DogePreTrainedModel(PreTrainedModel):
     config: DogeConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["DogeDecoderLayer"]
+    _no_split_modules = ["DogeDecoderLayer", "DogeRMSNorm"]
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn = False
     _supports_sdpa = True
diff --git a/src/transformers/models/dots1/modeling_dots1.py b/src/transformers/models/dots1/modeling_dots1.py
index ea500c064512..ad960dc0ff90 100644
--- a/src/transformers/models/dots1/modeling_dots1.py
+++ b/src/transformers/models/dots1/modeling_dots1.py
@@ -417,7 +417,7 @@ class Dots1PreTrainedModel(PreTrainedModel):
     config: Dots1Config
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["Dots1DecoderLayer"]
+    _no_split_modules = ["Dots1DecoderLayer", "Dots1RMSNorm"]
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn = True
     _supports_sdpa = True
diff --git a/src/transformers/models/ernie4_5/modeling_ernie4_5.py b/src/transformers/models/ernie4_5/modeling_ernie4_5.py
index 13ec6fb3a3b6..89e69c97e249 100644
--- a/src/transformers/models/ernie4_5/modeling_ernie4_5.py
+++ b/src/transformers/models/ernie4_5/modeling_ernie4_5.py
@@ -310,7 +310,7 @@ class Ernie4_5PreTrainedModel(PreTrainedModel):
     config: Ernie4_5Config
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["Ernie4_5DecoderLayer"]
+    _no_split_modules = ["Ernie4_5DecoderLayer", "Ernie4_5RMSNorm"]
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn = True
     _supports_sdpa = True
diff --git a/src/transformers/models/gemma/modeling_gemma.py b/src/transformers/models/gemma/modeling_gemma.py
index ef0a688d4608..5b3cfcc22910 100644
--- a/src/transformers/models/gemma/modeling_gemma.py
+++ b/src/transformers/models/gemma/modeling_gemma.py
@@ -310,7 +310,7 @@ class GemmaPreTrainedModel(PreTrainedModel):
     config: GemmaConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["GemmaDecoderLayer"]
+    _no_split_modules = ["GemmaDecoderLayer", "GemmaRMSNorm"]
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn = True
     _supports_sdpa = True
diff --git a/src/transformers/models/gemma2/modeling_gemma2.py b/src/transformers/models/gemma2/modeling_gemma2.py
index 2a218338384a..5e76f25c5e38 100644
--- a/src/transformers/models/gemma2/modeling_gemma2.py
+++ b/src/transformers/models/gemma2/modeling_gemma2.py
@@ -340,7 +340,7 @@ class Gemma2PreTrainedModel(PreTrainedModel):
     config: Gemma2Config
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["Gemma2DecoderLayer"]
+    _no_split_modules = ["Gemma2DecoderLayer", "Gemma2RMSNorm"]
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn = True
     _supports_sdpa = True
diff --git a/src/transformers/models/glm/modeling_glm.py b/src/transformers/models/glm/modeling_glm.py
index 59c9f39da527..2b76445831a6 100644
--- a/src/transformers/models/glm/modeling_glm.py
+++ b/src/transformers/models/glm/modeling_glm.py
@@ -326,7 +326,7 @@ class GlmPreTrainedModel(PreTrainedModel):
     config: GlmConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["GlmDecoderLayer"]
+    _no_split_modules = ["GlmDecoderLayer", "GlmRMSNorm"]
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn = True
     _supports_sdpa = True
diff --git a/src/transformers/models/glm4/modeling_glm4.py b/src/transformers/models/glm4/modeling_glm4.py
index dafab297f566..a76b0f53f7a7 100644
--- a/src/transformers/models/glm4/modeling_glm4.py
+++ b/src/transformers/models/glm4/modeling_glm4.py
@@ -330,7 +330,7 @@ class Glm4PreTrainedModel(PreTrainedModel):
     config: Glm4Config
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["Glm4DecoderLayer"]
+    _no_split_modules = ["Glm4DecoderLayer", "Glm4RMSNorm"]
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn = True
     _supports_sdpa = True
diff --git a/src/transformers/models/glm4_moe/modeling_glm4_moe.py b/src/transformers/models/glm4_moe/modeling_glm4_moe.py
index cb695ffbe638..34d8836f1729 100644
--- a/src/transformers/models/glm4_moe/modeling_glm4_moe.py
+++ b/src/transformers/models/glm4_moe/modeling_glm4_moe.py
@@ -402,7 +402,7 @@ class Glm4MoePreTrainedModel(PreTrainedModel):
     config: Glm4MoeConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["Glm4MoeDecoderLayer"]
+    _no_split_modules = ["Glm4MoeDecoderLayer", "Glm4MoeRMSNorm"]
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn = True
     _supports_sdpa = True
diff --git a/src/transformers/models/gpt_oss/modeling_gpt_oss.py b/src/transformers/models/gpt_oss/modeling_gpt_oss.py
index 0d5c936e8adc..34c620c13112 100644
--- a/src/transformers/models/gpt_oss/modeling_gpt_oss.py
+++ b/src/transformers/models/gpt_oss/modeling_gpt_oss.py
@@ -393,7 +393,7 @@ class GptOssPreTrainedModel(PreTrainedModel):
     config: GptOssConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["GptOssDecoderLayer"]
+    _no_split_modules = ["GptOssDecoderLayer", "GptOssRMSNorm"]
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn = True
     _supports_sdpa = False
diff --git a/src/transformers/models/granite/modeling_granite.py b/src/transformers/models/granite/modeling_granite.py
index 846865c55508..bb651329ae75 100644
--- a/src/transformers/models/granite/modeling_granite.py
+++ b/src/transformers/models/granite/modeling_granite.py
@@ -306,7 +306,7 @@ class GranitePreTrainedModel(PreTrainedModel):
     config: GraniteConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["GraniteDecoderLayer"]
+    _no_split_modules = ["GraniteDecoderLayer", "GraniteRMSNorm"]
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn = True
     _supports_sdpa = True
diff --git a/src/transformers/models/helium/modeling_helium.py b/src/transformers/models/helium/modeling_helium.py
index 9f4a2e73affd..5b44a8b759f5 100644
--- a/src/transformers/models/helium/modeling_helium.py
+++ b/src/transformers/models/helium/modeling_helium.py
@@ -311,7 +311,7 @@ class HeliumPreTrainedModel(PreTrainedModel):
     config: HeliumConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["HeliumDecoderLayer"]
+    _no_split_modules = ["HeliumDecoderLayer", "HeliumRMSNorm"]
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn = True
     _supports_sdpa = True
diff --git a/src/transformers/models/lfm2/modeling_lfm2.py b/src/transformers/models/lfm2/modeling_lfm2.py
index 5ea4314968e2..744ce70ab00a 100644
--- a/src/transformers/models/lfm2/modeling_lfm2.py
+++ b/src/transformers/models/lfm2/modeling_lfm2.py
@@ -577,7 +577,7 @@ class Lfm2PreTrainedModel(PreTrainedModel):
     config: Lfm2Config
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["Lfm2DecoderLayer"]
+    _no_split_modules = ["Lfm2DecoderLayer", "Lfm2RMSNorm"]
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn = True
     _supports_sdpa = True
diff --git a/src/transformers/models/llama/modeling_llama.py b/src/transformers/models/llama/modeling_llama.py
index f43a5fc9b523..3c2c39b62c5f 100644
--- a/src/transformers/models/llama/modeling_llama.py
+++ b/src/transformers/models/llama/modeling_llama.py
@@ -316,7 +316,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
     config: LlamaConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["LlamaDecoderLayer"]
+    _no_split_modules = ["LlamaDecoderLayer", "LlamaRMSNorm"]
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn = True
     _supports_sdpa = True
diff --git a/src/transformers/models/minimax/modeling_minimax.py b/src/transformers/models/minimax/modeling_minimax.py
index 633e053e2d54..efad290bc90d 100644
--- a/src/transformers/models/minimax/modeling_minimax.py
+++ b/src/transformers/models/minimax/modeling_minimax.py
@@ -580,7 +580,7 @@ class MiniMaxPreTrainedModel(PreTrainedModel):
     config: MiniMaxConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["MiniMaxDecoderLayer"]
+    _no_split_modules = ["MiniMaxDecoderLayer", "MiniMaxRMSNorm"]
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn = True
     _supports_sdpa = True
diff --git a/src/transformers/models/mistral/modeling_mistral.py b/src/transformers/models/mistral/modeling_mistral.py
index 5b7c7b2c1790..ac63b0100dfc 100644
--- a/src/transformers/models/mistral/modeling_mistral.py
+++ b/src/transformers/models/mistral/modeling_mistral.py
@@ -253,7 +253,7 @@ class MistralPreTrainedModel(PreTrainedModel):
     config: MistralConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["MistralDecoderLayer"]
+    _no_split_modules = ["MistralDecoderLayer", "MistralRMSNorm"]
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn = True
     _supports_sdpa = True
diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py
index 2412092aeb86..8297364f6e60 100644
--- a/src/transformers/models/mixtral/modeling_mixtral.py
+++ b/src/transformers/models/mixtral/modeling_mixtral.py
@@ -383,7 +383,7 @@ class MixtralPreTrainedModel(PreTrainedModel):
     config: MixtralConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["MixtralDecoderLayer"]
+    _no_split_modules = ["MixtralDecoderLayer", "MixtralRMSNorm"]
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn = True
     _supports_sdpa = True
diff --git a/src/transformers/models/olmo/modeling_olmo.py b/src/transformers/models/olmo/modeling_olmo.py
index 2d6e6e7092cb..a799e7112161 100644
--- a/src/transformers/models/olmo/modeling_olmo.py
+++ b/src/transformers/models/olmo/modeling_olmo.py
@@ -293,7 +293,7 @@ class OlmoPreTrainedModel(PreTrainedModel):
     config: OlmoConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["OlmoDecoderLayer"]
+    _no_split_modules = ["OlmoDecoderLayer", "OlmoRMSNorm"]
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn = True
     _supports_sdpa = True
diff --git a/src/transformers/models/olmo2/modeling_olmo2.py b/src/transformers/models/olmo2/modeling_olmo2.py
index 3fe4cfaf91de..4e96ef8397d2 100644
--- a/src/transformers/models/olmo2/modeling_olmo2.py
+++ b/src/transformers/models/olmo2/modeling_olmo2.py
@@ -298,7 +298,7 @@ class Olmo2PreTrainedModel(PreTrainedModel):
     config: Olmo2Config
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["Olmo2DecoderLayer"]
+    _no_split_modules = ["Olmo2DecoderLayer", "Olmo2RMSNorm"]
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn = True
     _supports_sdpa = True
diff --git a/src/transformers/models/phi/modeling_phi.py b/src/transformers/models/phi/modeling_phi.py
index 165a2b887423..7ecb3e07cca6 100644
--- a/src/transformers/models/phi/modeling_phi.py
+++ b/src/transformers/models/phi/modeling_phi.py
@@ -297,7 +297,7 @@ class PhiPreTrainedModel(PreTrainedModel):
     config: PhiConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["PhiDecoderLayer"]
+    _no_split_modules = ["PhiDecoderLayer", "PhiRMSNorm"]
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn = True
     _supports_sdpa = True
diff --git a/src/transformers/models/phi3/modeling_phi3.py b/src/transformers/models/phi3/modeling_phi3.py
index 23820075a020..ec6f6bc731c5 100644
--- a/src/transformers/models/phi3/modeling_phi3.py
+++ b/src/transformers/models/phi3/modeling_phi3.py
@@ -284,7 +284,7 @@ class Phi3PreTrainedModel(PreTrainedModel):
     config: Phi3Config
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["Phi3DecoderLayer"]
+    _no_split_modules = ["Phi3DecoderLayer", "Phi3RMSNorm"]
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn = True
     _supports_sdpa = True
diff --git a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py
index ad2ef3e07124..e0f96db015c9 100644
--- a/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py
+++ b/src/transformers/models/phi4_multimodal/modeling_phi4_multimodal.py
@@ -1518,7 +1518,7 @@ class Phi4MultimodalPreTrainedModel(PreTrainedModel):
     config: Phi4MultimodalConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["Phi4MultimodalDecoderLayer"]
+    _no_split_modules = ["Phi4MultimodalDecoderLayer", "Phi4MultimodalRMSNorm"]
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn = True
     _supports_sdpa = True
diff --git a/src/transformers/models/qwen2/modeling_qwen2.py b/src/transformers/models/qwen2/modeling_qwen2.py
index 2fcb44372fe4..fbec404a8f33 100644
--- a/src/transformers/models/qwen2/modeling_qwen2.py
+++ b/src/transformers/models/qwen2/modeling_qwen2.py
@@ -256,7 +256,7 @@ class Qwen2PreTrainedModel(PreTrainedModel):
     config: Qwen2Config
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["Qwen2DecoderLayer"]
+    _no_split_modules = ["Qwen2DecoderLayer", "Qwen2RMSNorm"]
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn = True
     _supports_sdpa = True
diff --git a/src/transformers/models/qwen3/modeling_qwen3.py b/src/transformers/models/qwen3/modeling_qwen3.py
index 81b16c4ee6b6..b9db22b9f026 100644
--- a/src/transformers/models/qwen3/modeling_qwen3.py
+++ b/src/transformers/models/qwen3/modeling_qwen3.py
@@ -282,7 +282,7 @@ class Qwen3PreTrainedModel(PreTrainedModel):
     config: Qwen3Config
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["Qwen3DecoderLayer"]
+    _no_split_modules = ["Qwen3DecoderLayer", "Qwen3RMSNorm"]
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn = True
     _supports_sdpa = True
diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py
index 2056e7c76a3a..ead0950b14f6 100644
--- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py
+++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py
@@ -406,7 +406,7 @@ class Qwen3MoePreTrainedModel(PreTrainedModel):
     config: Qwen3MoeConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["Qwen3MoeDecoderLayer"]
+    _no_split_modules = ["Qwen3MoeDecoderLayer", "Qwen3MoeRMSNorm"]
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn = True
     _supports_sdpa = True
diff --git a/src/transformers/models/smollm3/modeling_smollm3.py b/src/transformers/models/smollm3/modeling_smollm3.py
index 1e08e288193b..a2b60b9a5a04 100644
--- a/src/transformers/models/smollm3/modeling_smollm3.py
+++ b/src/transformers/models/smollm3/modeling_smollm3.py
@@ -286,7 +286,7 @@ class SmolLM3PreTrainedModel(PreTrainedModel):
     config: SmolLM3Config
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["SmolLM3DecoderLayer"]
+    _no_split_modules = ["SmolLM3DecoderLayer", "SmolLM3RMSNorm"]
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn = True
     _supports_sdpa = True
diff --git a/src/transformers/models/starcoder2/modeling_starcoder2.py b/src/transformers/models/starcoder2/modeling_starcoder2.py
index dfdfec22ca99..2ce408450de8 100644
--- a/src/transformers/models/starcoder2/modeling_starcoder2.py
+++ b/src/transformers/models/starcoder2/modeling_starcoder2.py
@@ -292,7 +292,7 @@ class Starcoder2PreTrainedModel(PreTrainedModel):
     config: Starcoder2Config
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["Starcoder2DecoderLayer"]
+    _no_split_modules = ["Starcoder2DecoderLayer", "Starcoder2RMSNorm"]
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn = True
     _supports_sdpa = True
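
For context (not part of the patch): _no_split_modules is forwarded to accelerate as no_split_module_classes when a checkpoint is loaded with device_map="auto", so listing the *RMSNorm classes keeps each norm module and its weight on a single device instead of letting the automatic device map cut through it at a layer boundary. A minimal sketch of that loading path, using a placeholder checkpoint name:

# Minimal sketch (not part of this patch) of how _no_split_modules is consumed when
# loading with device_map="auto"; the checkpoint name below is a placeholder.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.2-1B",    # placeholder checkpoint
    device_map="auto",            # accelerate infers the module-to-device placement
    torch_dtype=torch.bfloat16,
)

# Classes listed in _no_split_modules (now including "LlamaRMSNorm") are passed to
# accelerate as no_split_module_classes, so no module of those classes is ever assigned
# partially to one device and partially to another.
print(model.hf_device_map)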