diff --git a/src/transformers/models/chameleon/modeling_chameleon.py b/src/transformers/models/chameleon/modeling_chameleon.py index b003f9e4e8f5..dd0d0a5582f2 100644 --- a/src/transformers/models/chameleon/modeling_chameleon.py +++ b/src/transformers/models/chameleon/modeling_chameleon.py @@ -817,7 +817,11 @@ class ChameleonPreTrainedModel(PreTrainedModel): ) class ChameleonVQVAE(ChameleonPreTrainedModel): config: ChameleonVQVAEConfig - _no_split_modules = ["ChameleonVQVAEVectorQuantizer"] + _no_split_modules = [ + "ChameleonVQVAEVectorQuantizer", + "ChameleonVQVAEEncoderAttnBlock", + "ChameleonVQVAEEncoderResnetBlock", + ] def __init__(self, config: ChameleonVQVAEConfig): super().__init__(config)