Skip to content

Commit a0e88e8

Browse files
committed
Source transformation from FusionEmbedding to nn.Embedding
1 parent e965dab commit a0e88e8

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

examples/models/llama2/export_llama_lib.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -883,8 +883,12 @@ def _load_llama_model(
883883
def _get_source_transforms( # noqa
884884
modelname: str, dtype_override: Optional[DType], args
885885
) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
886+
is_torchtune_model = modelname in TORCHTUNE_DEFINED_MODELS
886887
transforms = []
887888

889+
if is_torchtune_model:
890+
transforms.append(replace_fusion_embeddings_with_nn_embedding)
891+
888892
if args.use_spin_quant:
889893
if args.use_spin_quant == "cuda":
890894
from .source_transformation.spin_quant import (
@@ -971,4 +975,6 @@ def _get_source_transforms( # noqa
971975
transforms.append(replace_sdpa_with_simple_sdpa)
972976
transforms.append(replace_kv_cache_with_coreml_kv_cache)
973977

978+
print(f"Performing the following transforms: {[transform.__name__ for transform in transforms]}")
979+
974980
return transforms

0 commit comments

Comments
 (0)