
Commit 84a6789

Enable different torch dtype in sub models (#34873)
* fix
* fix test
* add tests
* add more tests
* fix tests
* supposed to be a torch.dtype test
* handle BC and make fp32 default
1 parent 8708917 · commit 84a6789
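In short, `torch_dtype` can now be given as a dict keyed by sub-config name when loading a composite model. A minimal usage sketch of the new API, following the pattern exercised by the commit's own tests (the checkpoint below is the tiny test model those tests use; any composite checkpoint works the same way):

```python
# Sketch of the API this commit enables: one dtype per sub-config, with the
# "" key covering modules that belong to no sub-config (e.g. the projector).
import torch
from transformers import LlavaForConditionalGeneration

model = LlavaForConditionalGeneration.from_pretrained(
    "hf-internal-testing/tiny-random-LlavaForConditionalGeneration",
    torch_dtype={"text_config": "float32", "vision_config": "float16", "": "bfloat16"},
)
assert model.language_model.dtype == torch.float32
assert model.vision_tower.dtype == torch.float16
```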

File tree

5 files changed: +155 −68 lines

- src/transformers/configuration_utils.py
- src/transformers/modeling_utils.py
- src/transformers/models/chameleon/modeling_chameleon.py
- tests/models/qwen2_vl/test_modeling_qwen2_vl.py
- tests/utils/test_modeling_utils.py

src/transformers/configuration_utils.py

Lines changed: 5 additions & 2 deletions
```diff
@@ -994,8 +994,11 @@ def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None:
         converts torch.dtype to a string of just the type. For example, `torch.float32` get converted into *"float32"*
         string, which can then be stored in the json format.
         """
-        if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str):
-            d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
+        if d.get("torch_dtype", None) is not None:
+            if isinstance(d["torch_dtype"], dict):
+                d["torch_dtype"] = {k: str(v).split(".")[-1] for k, v in d["torch_dtype"].items()}
+            elif not isinstance(d["torch_dtype"], str):
+                d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
         for value in d.values():
             if isinstance(value, dict):
                 self.dict_torch_dtype_to_str(value)
```
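A standalone sketch of the new dict branch (plain Python, mirroring the dict comprehension above): `torch.dtype` values in a per-sub-config mapping are reduced to their string names so the config stays JSON-serializable.

```python
# Standalone illustration of the dict branch added above: torch.dtype values
# are converted to bare type names ("float32") suitable for JSON storage.
import torch

d = {"torch_dtype": {"text_config": torch.float32, "vision_config": torch.float16}}
d["torch_dtype"] = {k: str(v).split(".")[-1] for k, v in d["torch_dtype"].items()}
print(d["torch_dtype"])  # {'text_config': 'float32', 'vision_config': 'float16'}
```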

src/transformers/modeling_utils.py

Lines changed: 37 additions & 9 deletions
```diff
@@ -1312,11 +1312,10 @@ def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
                 "`PretrainedConfig`. To create a model from a pretrained model use "
                 f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
             )
-        # Save config and origin of the pretrained weights if given in model
         if not getattr(config, "_attn_implementation_autoset", False):
-            config = self._autoset_attn_implementation(
-                config, torch_dtype=torch.get_default_dtype(), check_device_map=False
-            )
+            # config usually has a `torch_dtype` but we need the next line for the `no_super_init` tests
+            dtype = config.torch_dtype if hasattr(config, "torch_dtype") else torch.get_default_dtype()
+            config = self._autoset_attn_implementation(config, torch_dtype=dtype, check_device_map=False)
         self.config = config

         # for initialization of the loss
@@ -1411,7 +1410,10 @@ def _from_config(cls, config, **kwargs):
         # when we init a model from within another model (e.g. VLMs) and dispatch on FA2
         # a warning is raised that dtype should be fp16. Since we never pass dtype from within
         # modeling code, we can try to infer it here same way as done in `from_pretrained`
-        torch_dtype = kwargs.pop("torch_dtype", torch.get_default_dtype())
+        torch_dtype = kwargs.pop("torch_dtype", config.torch_dtype)
+        if isinstance(torch_dtype, str):
+            torch_dtype = getattr(torch, torch_dtype)
+
         use_flash_attention_2 = kwargs.pop("use_flash_attention_2", False)

         # override default dtype if needed
@@ -4020,11 +4022,37 @@ def from_pretrained(
                     )
                 elif hasattr(torch, torch_dtype):
                     torch_dtype = getattr(torch, torch_dtype)
-                else:
-                    raise ValueError(
-                        f'`torch_dtype` can be one of: `torch.dtype`, `"auto"` or a string of a valid `torch.dtype`, but received {torch_dtype}'
-                    )
+                    for sub_config_key in config.sub_configs.keys():
+                        sub_config = getattr(config, sub_config_key)
+                        sub_config.torch_dtype = torch_dtype
+            elif isinstance(torch_dtype, torch.dtype):
+                pass
+            elif isinstance(torch_dtype, dict):
+                for key, curr_dtype in torch_dtype.items():
+                    if hasattr(config, key):
+                        value = getattr(config, key)
+                        value.torch_dtype = curr_dtype
+                # main torch dtype for modules that aren't part of any sub-config
+                torch_dtype = torch_dtype.get("")
+                config.torch_dtype = torch_dtype
+                if isinstance(torch_dtype, str) and hasattr(torch, torch_dtype):
+                    torch_dtype = getattr(torch, torch_dtype)
+                elif torch_dtype is None:
+                    torch_dtype = torch.float32
+            else:
+                raise ValueError(
+                    f"`torch_dtype` can be one of: `torch.dtype`, `'auto'`, a string of a valid `torch.dtype` or a `dict` with valid `torch_dtype` "
+                    f"for each sub-config in composite configs, but received {torch_dtype}"
+                )
+
             dtype_orig = cls._set_default_torch_dtype(torch_dtype)
+        else:
+            # set fp32 as the default dtype for BC
+            default_dtype = str(torch.get_default_dtype()).split(".")[-1]
+            config.torch_dtype = default_dtype
+            for key in config.sub_configs.keys():
+                value = getattr(config, key)
+                value.torch_dtype = default_dtype

         # Check if `_keep_in_fp32_modules` is not None
         use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
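```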

src/transformers/models/chameleon/modeling_chameleon.py

Lines changed: 57 additions & 57 deletions
```diff
@@ -967,62 +967,6 @@ def forward(self, pixel_values: torch.LongTensor):
         return last_hidden_state


-CHAMELEON_VQ_START_DOCSTRING = r"""
-    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
-    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
-    etc.)
-
-    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
-    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
-    and behavior.
-
-    Parameters:
-        config ([`ChameleonVQVAEConfig`]):
-            Model configuration class with all the parameters of the model. Initializing with a config file does not
-            load the weights associated with the model, only the configuration. Check out the
-            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
-"""
-
-
-@add_start_docstrings(
-    """The VQ-VAE model used in Chameleon for encoding/decoding images into discrete tokens.
-    This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from
-    [ Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv Taigman](https://arxiv.org/abs/2203.13131).
-    """,
-    CHAMELEON_VQ_START_DOCSTRING,
-)
-class ChameleonVQVAE(PreTrainedModel):
-    config_class = ChameleonVQVAEConfig
-    _no_split_modules = ["ChameleonVQVAEVectorQuantizer"]
-
-    def _init_weights(self, module):
-        std = self.config.initializer_range
-        if isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=std)
-        elif isinstance(module, nn.GroupNorm):
-            module.bias.data.zero_()
-            module.weight.data.fill_(1.0)
-        elif isinstance(module, (nn.Linear, nn.Conv2d)):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.bias is not None:
-                module.bias.data.zero_()
-
-    def __init__(self, config: ChameleonVQVAEConfig):
-        super().__init__(config)
-
-        self.encoder = ChameleonVQVAEEncoder(config)
-        self.quantize = ChameleonVQVAEVectorQuantizer(config)
-        self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1)
-        self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, config.latent_channels, 1)
-        self.eval()  # Chameleon's VQ model is frozen
-
-    def encode(self, pixel_values: torch.LongTensor):
-        hidden_states = self.encoder(pixel_values)
-        hidden_states = self.quant_conv(hidden_states)
-        quant, emb_loss, indices = self.quantize(hidden_states)
-        return quant, emb_loss, indices
-
-
 class ChameleonImageVocabularyMapping:
     """
     A class for mapping discrete image tokens from VQGAN to BPE tokens.
@@ -1118,6 +1062,62 @@ def _init_weights(self, module):
             module.weight.data[module.padding_idx].zero_()


+CHAMELEON_VQ_START_DOCSTRING = r"""
+    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
+    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
+    etc.)
+
+    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
+    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
+    and behavior.
+
+    Parameters:
+        config ([`ChameleonVQVAEConfig`]):
+            Model configuration class with all the parameters of the model. Initializing with a config file does not
+            load the weights associated with the model, only the configuration. Check out the
+            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
+"""
+
+
+@add_start_docstrings(
+    """The VQ-VAE model used in Chameleon for encoding/decoding images into discrete tokens.
+    This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from
+    [ Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv Taigman](https://arxiv.org/abs/2203.13131).
+    """,
+    CHAMELEON_VQ_START_DOCSTRING,
+)
+class ChameleonVQVAE(ChameleonPreTrainedModel):
+    config_class = ChameleonVQVAEConfig
+    _no_split_modules = ["ChameleonVQVAEVectorQuantizer"]
+
+    def _init_weights(self, module):
+        std = self.config.initializer_range
+        if isinstance(module, nn.Embedding):
+            module.weight.data.normal_(mean=0.0, std=std)
+        elif isinstance(module, nn.GroupNorm):
+            module.bias.data.zero_()
+            module.weight.data.fill_(1.0)
+        elif isinstance(module, (nn.Linear, nn.Conv2d)):
+            module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None:
+                module.bias.data.zero_()
+
+    def __init__(self, config: ChameleonVQVAEConfig):
+        super().__init__(config)
+
+        self.encoder = ChameleonVQVAEEncoder(config)
+        self.quantize = ChameleonVQVAEVectorQuantizer(config)
+        self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1)
+        self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, config.latent_channels, 1)
+        self.eval()  # Chameleon's VQ model is frozen
+
+    def encode(self, pixel_values: torch.LongTensor):
+        hidden_states = self.encoder(pixel_values)
+        hidden_states = self.quant_conv(hidden_states)
+        quant, emb_loss, indices = self.quantize(hidden_states)
+        return quant, emb_loss, indices
+
+
 CHAMELEON_INPUTS_DOCSTRING = r"""
     Args:
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1211,7 +1211,7 @@ def __init__(self, config: ChameleonConfig):
             [decoder_layer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
         )
         self.norm = ChameleonRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-        self.vqmodel = ChameleonVQVAE(config.vq_config)
+        self.vqmodel = ChameleonVQVAE._from_config(config.vq_config)
         self.gradient_checkpointing = False

         # Initialize weights and apply final processing
```
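Constructing the sub-model through `_from_config` (rather than calling the class directly) is what lets it honor the dtype stored on `vq_config` instead of the process-wide default. A hedged usage sketch; the checkpoint name is illustrative, and we assume `vq_config` can be targeted in the dtype dict like any other sub-config attribute:

```python
# Hedged sketch: with the `_from_config` change, the frozen VQ model can be
# kept in fp32 while the rest of Chameleon loads in bf16. Substitute any
# Chameleon checkpoint you have access to.
import torch
from transformers import ChameleonForConditionalGeneration

model = ChameleonForConditionalGeneration.from_pretrained(
    "facebook/chameleon-7b",
    torch_dtype={"vq_config": "float32", "": "bfloat16"},
)
```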

tests/models/qwen2_vl/test_modeling_qwen2_vl.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -227,6 +227,7 @@ class Qwen2VLModelTest(ModelTesterMixin, GenerationTesterMixin, unittest.TestCase
     pipeline_model_mapping = {"image-text-to-text": Qwen2VLForConditionalGeneration}
     test_pruning = False
     test_head_masking = False
+    _is_composite = True

     def setUp(self):
         self.model_tester = Qwen2VLVisionText2TextModelTester(self)
```

tests/utils/test_modeling_utils.py

Lines changed: 55 additions & 0 deletions
```diff
@@ -37,6 +37,7 @@
     AutoModel,
     AutoModelForImageClassification,
     AutoModelForSequenceClassification,
+    LlavaForConditionalGeneration,
     OwlViTForObjectDetection,
     PretrainedConfig,
     is_torch_available,
@@ -300,6 +301,7 @@ def test_local_files_only(self):
 TINY_BERT_FOR_TOKEN_CLASSIFICATION = "hf-internal-testing/tiny-bert-for-token-classification"
 TINY_MISTRAL = "hf-internal-testing/tiny-random-MistralForCausalLM"
 TINY_IMAGE_CLASSIF = "hf-internal-testing/tiny-random-SiglipForImageClassification"
+TINY_LLAVA = "hf-internal-testing/tiny-random-LlavaForConditionalGeneration"

 LOG = logging.get_logger(__name__)

@@ -460,6 +462,59 @@ def test_model_from_config_torch_dtype_str(self):
         with self.assertRaises(ValueError):
             model = AutoModel.from_pretrained(TINY_T5, torch_dtype="int64")

+    def test_model_from_config_torch_dtype_composite(self):
+        """
+        Test that from_pretrained works with torch_dtype being as a dict per each sub-config in composite config
+        """
+        # should be able to set torch_dtype as a simple string and the model loads it correctly
+        model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype="float32")
+        self.assertEqual(model.language_model.dtype, torch.float32)
+        self.assertEqual(model.vision_tower.dtype, torch.float32)
+
+        model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype="float16")
+        self.assertEqual(model.language_model.dtype, torch.float16)
+        self.assertEqual(model.vision_tower.dtype, torch.float16)
+
+        # should be able to set torch_dtype as a dict for each sub-config
+        model = LlavaForConditionalGeneration.from_pretrained(
+            TINY_LLAVA, torch_dtype={"text_config": "float32", "vision_config": "float16", "": "bfloat16"}
+        )
+        self.assertEqual(model.language_model.dtype, torch.float32)
+        self.assertEqual(model.vision_tower.dtype, torch.float16)
+        self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16)
+
+        # should be able to set the values as torch.dtype (not str)
+        model = LlavaForConditionalGeneration.from_pretrained(
+            TINY_LLAVA, torch_dtype={"text_config": torch.float32, "vision_config": torch.float16, "": torch.bfloat16}
+        )
+        self.assertEqual(model.language_model.dtype, torch.float32)
+        self.assertEqual(model.vision_tower.dtype, torch.float16)
+        self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.bfloat16)
+
+        # should be able to set the values in configs directly and pass it to `from_pretrained`
+        config = copy.deepcopy(model.config)
+        config.text_config.torch_dtype = torch.float32
+        config.vision_config.torch_dtype = torch.bfloat16
+        config.torch_dtype = torch.float16
+        model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, config=config, torch_dtype="auto")
+        self.assertEqual(model.language_model.dtype, torch.float32)
+        self.assertEqual(model.vision_tower.dtype, torch.bfloat16)
+        self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float16)
+
+        # but if the model has `_keep_in_fp32_modules` then those modules should be in fp32 no matter what
+        LlavaForConditionalGeneration._keep_in_fp32_modules = ["multi_modal_projector"]
+        model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, config=config, torch_dtype="auto")
+        self.assertEqual(model.language_model.dtype, torch.float32)
+        self.assertEqual(model.vision_tower.dtype, torch.bfloat16)
+        self.assertEqual(model.multi_modal_projector.linear_1.weight.dtype, torch.float32)
+
+        # torch.set_default_dtype() supports only float dtypes, so will fail with non-float type
+        with self.assertRaises(ValueError):
+            model = LlavaForConditionalGeneration.from_pretrained(TINY_LLAVA, torch_dtype="int64")
+            model = LlavaForConditionalGeneration.from_pretrained(
+                TINY_LLAVA, torch_dtype={"text_config": "float32", "vision_config": "int64", "": "float16"}
+            )
+
     @require_torch
     def test_model_from_pretrained_meta_device(self):
         def is_on_meta(model_id, dtype):
```
