File tree Expand file tree Collapse file tree 2 files changed +45
-0
lines changed
source_transformation/torchtune Expand file tree Collapse file tree 2 files changed +45
-0
lines changed Original file line number Diff line number Diff line change @@ -883,8 +883,12 @@ def _load_llama_model(
883883def _get_source_transforms ( # noqa
884884 modelname : str , dtype_override : Optional [DType ], args
885885) -> List [Callable [[torch .nn .Module ], torch .nn .Module ]]:
886+ is_torchtune_model = modelname in TORCHTUNE_DEFINED_MODELS
886887 transforms = []
887888
889+ if is_torchtune_model :
890+ transforms .append (replace_fusion_embeddings_with_nn_embedding )
891+
888892 if args .use_spin_quant :
889893 if args .use_spin_quant == "cuda" :
890894 from .source_transformation .spin_quant import (
@@ -971,4 +975,6 @@ def _get_source_transforms( # noqa
971975 transforms .append (replace_sdpa_with_simple_sdpa )
972976 transforms .append (replace_kv_cache_with_coreml_kv_cache )
973977
978+ print (f"Performing the following transforms: { [transform .__name__ for transform in transforms ]} " )
979+
974980 return transforms
Original file line number Diff line number Diff line change 1+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2+ # All rights reserved.
3+ #
4+ # This source code is licensed under the BSD-style license found in the
5+ # LICENSE file in the root directory of this source tree.
6+
7+ import logging
8+
9+ import torch
10+ from torchtune .modules .model_fusion ._fusion import FusionEmbedding
11+
12+
13+ def _replace_fusion_embeddings_with_nn_embedding (module : torch .nn .Module ) -> None :
14+ """
15+ Replace TorchTune's FusionEmbedding with nn.Embedding. This is because
16+ the FusionEmbedding is meant for efficient training and bears no
17+ effect on inference. This is better since we get to avoid some of the
18+ potentially missing torch ops in the FusionEmbedding such as
19+ masked_select and masked_scatter.
20+ """
21+
22+ for name , child in module .named_children ():
23+ if isinstance (child , FusionEmbedding ):
24+ setattr (
25+ module ,
26+ name ,
27+ torch .nn .Embedding (
28+ child .embedding .num_embeddings + child .fusion_embedding .num_embeddings ,
29+ child .dim ,
30+ )
31+ )
32+ else :
33+ _replace_fusion_embeddings_with_nn_embedding (child )
34+
35+ def replace_fusion_embeddings_with_nn_embedding (module : torch .nn .Module ) -> torch .nn .Module :
36+ logging .info ("Replacing fusion embeddings with nn.embeddings." )
37+ _replace_fusion_embeddings_with_nn_embedding (module )
38+ return module
39+
You can’t perform that action at this time.
0 commit comments