@@ -967,62 +967,6 @@ def forward(self, pixel_values: torch.LongTensor):
967967 return last_hidden_state
968968
969969
970- CHAMELEON_VQ_START_DOCSTRING = r"""
971- This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
972- library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
973- etc.)
974-
975- This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
976- Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
977- and behavior.
978-
979- Parameters:
980- config ([`ChameleonVQVAEConfig`]):
981- Model configuration class with all the parameters of the model. Initializing with a config file does not
982- load the weights associated with the model, only the configuration. Check out the
983- [`~PreTrainedModel.from_pretrained`] method to load the model weights.
984- """
985-
986-
987- @add_start_docstrings (
988- """The VQ-VAE model used in Chameleon for encoding/decoding images into discrete tokens.
989- This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from
990- [ Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv Taigman](https://arxiv.org/abs/2203.13131).
991- """ ,
992- CHAMELEON_VQ_START_DOCSTRING ,
993- )
994- class ChameleonVQVAE (PreTrainedModel ):
995- config_class = ChameleonVQVAEConfig
996- _no_split_modules = ["ChameleonVQVAEVectorQuantizer" ]
997-
998- def _init_weights (self , module ):
999- std = self .config .initializer_range
1000- if isinstance (module , nn .Embedding ):
1001- module .weight .data .normal_ (mean = 0.0 , std = std )
1002- elif isinstance (module , nn .GroupNorm ):
1003- module .bias .data .zero_ ()
1004- module .weight .data .fill_ (1.0 )
1005- elif isinstance (module , (nn .Linear , nn .Conv2d )):
1006- module .weight .data .normal_ (mean = 0.0 , std = std )
1007- if module .bias is not None :
1008- module .bias .data .zero_ ()
1009-
1010- def __init__ (self , config : ChameleonVQVAEConfig ):
1011- super ().__init__ (config )
1012-
1013- self .encoder = ChameleonVQVAEEncoder (config )
1014- self .quantize = ChameleonVQVAEVectorQuantizer (config )
1015- self .quant_conv = torch .nn .Conv2d (config .latent_channels , config .embed_dim , 1 )
1016- self .post_quant_conv = torch .nn .Conv2d (config .embed_dim , config .latent_channels , 1 )
1017- self .eval () # Chameleon's VQ model is frozen
1018-
1019- def encode (self , pixel_values : torch .LongTensor ):
1020- hidden_states = self .encoder (pixel_values )
1021- hidden_states = self .quant_conv (hidden_states )
1022- quant , emb_loss , indices = self .quantize (hidden_states )
1023- return quant , emb_loss , indices
1024-
1025-
1026970class ChameleonImageVocabularyMapping :
1027971 """
1028972 A class for mapping discrete image tokens from VQGAN to BPE tokens.
@@ -1118,6 +1062,62 @@ def _init_weights(self, module):
11181062 module .weight .data [module .padding_idx ].zero_ ()
11191063
11201064
1065+ CHAMELEON_VQ_START_DOCSTRING = r"""
1066+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
1067+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
1068+ etc.)
1069+
1070+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
1071+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
1072+ and behavior.
1073+
1074+ Parameters:
1075+ config ([`ChameleonVQVAEConfig`]):
1076+ Model configuration class with all the parameters of the model. Initializing with a config file does not
1077+ load the weights associated with the model, only the configuration. Check out the
1078+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
1079+ """
1080+
1081+
1082+ @add_start_docstrings (
1083+ """The VQ-VAE model used in Chameleon for encoding/decoding images into discrete tokens.
1084+ This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from
1085+ [ Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv Taigman](https://arxiv.org/abs/2203.13131).
1086+ """ ,
1087+ CHAMELEON_VQ_START_DOCSTRING ,
1088+ )
1089+ class ChameleonVQVAE (ChameleonPreTrainedModel ):
1090+ config_class = ChameleonVQVAEConfig
1091+ _no_split_modules = ["ChameleonVQVAEVectorQuantizer" ]
1092+
1093+ def _init_weights (self , module ):
1094+ std = self .config .initializer_range
1095+ if isinstance (module , nn .Embedding ):
1096+ module .weight .data .normal_ (mean = 0.0 , std = std )
1097+ elif isinstance (module , nn .GroupNorm ):
1098+ module .bias .data .zero_ ()
1099+ module .weight .data .fill_ (1.0 )
1100+ elif isinstance (module , (nn .Linear , nn .Conv2d )):
1101+ module .weight .data .normal_ (mean = 0.0 , std = std )
1102+ if module .bias is not None :
1103+ module .bias .data .zero_ ()
1104+
1105+ def __init__ (self , config : ChameleonVQVAEConfig ):
1106+ super ().__init__ (config )
1107+
1108+ self .encoder = ChameleonVQVAEEncoder (config )
1109+ self .quantize = ChameleonVQVAEVectorQuantizer (config )
1110+ self .quant_conv = torch .nn .Conv2d (config .latent_channels , config .embed_dim , 1 )
1111+ self .post_quant_conv = torch .nn .Conv2d (config .embed_dim , config .latent_channels , 1 )
1112+ self .eval () # Chameleon's VQ model is frozen
1113+
1114+ def encode (self , pixel_values : torch .LongTensor ):
1115+ hidden_states = self .encoder (pixel_values )
1116+ hidden_states = self .quant_conv (hidden_states )
1117+ quant , emb_loss , indices = self .quantize (hidden_states )
1118+ return quant , emb_loss , indices
1119+
1120+
11211121CHAMELEON_INPUTS_DOCSTRING = r"""
11221122 Args:
11231123 input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1211,7 +1211,7 @@ def __init__(self, config: ChameleonConfig):
12111211 [decoder_layer (config , layer_idx ) for layer_idx in range (config .num_hidden_layers )]
12121212 )
12131213 self .norm = ChameleonRMSNorm (config .hidden_size , eps = config .rms_norm_eps )
1214- self .vqmodel = ChameleonVQVAE (config .vq_config )
1214+ self .vqmodel = ChameleonVQVAE . _from_config (config .vq_config )
12151215 self .gradient_checkpointing = False
12161216
12171217 # Initialize weights and apply final processing
0 commit comments