1 file changed: +6 −0 lines changed

@@ -883,8 +883,12 @@ def _load_llama_model(
 def _get_source_transforms(  # noqa
     modelname: str, dtype_override: Optional[DType], args
 ) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
+    is_torchtune_model = modelname in TORCHTUNE_DEFINED_MODELS
     transforms = []

+    if is_torchtune_model:
+        transforms.append(replace_fusion_embeddings_with_nn_embedding)
+
     if args.use_spin_quant:
         if args.use_spin_quant == "cuda":
             from .source_transformation.spin_quant import (
@@ -971,4 +975,6 @@ def _get_source_transforms(  # noqa
     transforms.append(replace_sdpa_with_simple_sdpa)
     transforms.append(replace_kv_cache_with_coreml_kv_cache)

+    print(f"Performing the following transforms: {[transform.__name__ for transform in transforms]}")
+
     return transforms
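
For context, every entry in `transforms` is a `Callable[[torch.nn.Module], torch.nn.Module]` that rewrites the eager model before export; the transform added in this diff for torchtune-defined models, `replace_fusion_embeddings_with_nn_embedding`, follows that same signature. The sketch below is a minimal illustration of the pattern only: `apply_source_transforms` and the toy `replace_relu_with_gelu` transform are hypothetical names, not part of this change or the repository's API.

# Minimal sketch of the source-transform pattern, under the assumptions above.
from typing import Callable, List

import torch


def replace_relu_with_gelu(module: torch.nn.Module) -> torch.nn.Module:
    # Hypothetical transform: recursively swap every nn.ReLU child for nn.GELU.
    for name, child in module.named_children():
        if isinstance(child, torch.nn.ReLU):
            setattr(module, name, torch.nn.GELU())
        else:
            replace_relu_with_gelu(child)
    return module


def apply_source_transforms(
    model: torch.nn.Module,
    transforms: List[Callable[[torch.nn.Module], torch.nn.Module]],
) -> torch.nn.Module:
    # Apply each transform in order; each one returns the (possibly rewritten) model.
    for transform in transforms:
        model = transform(model)
    return model


if __name__ == "__main__":
    model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())
    transforms = [replace_relu_with_gelu]
    # Mirrors the logging added at the end of _get_source_transforms in the diff.
    print(f"Performing the following transforms: {[t.__name__ for t in transforms]}")
    model = apply_source_transforms(model, transforms)
    print(model)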