src/transformers/models/llama/modeling_llama.py (1 addition, 1 deletion)

@@ -311,7 +311,7 @@ class LlamaPreTrainedModel(PreTrainedModel):
     config: LlamaConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
-    _no_split_modules = ["LlamaDecoderLayer"]
+    _no_split_modules = ["LlamaDecoderLayer", "LlamaRMSNorm"]
     _skip_keys_device_placement = ["past_key_values"]
     _supports_flash_attn = True
     _supports_sdpa = True
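For context, `_no_split_modules` is the list of module classes that accelerate's device-map inference must keep whole on a single device when a model is loaded with `device_map="auto"`. Adding `LlamaRMSNorm` presumably prevents the auto-sharder from placing a norm layer on a different device from the tensors it normalizes. A minimal sketch of the code path this attribute affects; the checkpoint name is illustrative and not part of this change:

```python
# Minimal sketch: loading with device_map="auto" triggers accelerate's
# device-map inference, which honors the model's _no_split_modules list.
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",  # illustrative checkpoint (gated weights)
    device_map="auto",  # placement is inferred; listed modules are never split
)

# With this change, each LlamaRMSNorm (like each LlamaDecoderLayer) is
# placed intact on one device rather than being split across the device map.
```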