Skip to content

Commit b6044a0

Browse files
committed
Add a source transform converting TorchTune's FusionEmbedding to nn.Embedding
1 parent e145bd1 commit b6044a0

File tree

2 files changed

+45
-0
lines changed

2 files changed

+45
-0
lines changed

examples/models/llama2/export_llama_lib.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -883,8 +883,12 @@ def _load_llama_model(
883883
def _get_source_transforms( # noqa
884884
modelname: str, dtype_override: Optional[DType], args
885885
) -> List[Callable[[torch.nn.Module], torch.nn.Module]]:
886+
is_torchtune_model = modelname in TORCHTUNE_DEFINED_MODELS
886887
transforms = []
887888

889+
if is_torchtune_model:
890+
transforms.append(replace_fusion_embeddings_with_nn_embedding)
891+
888892
if args.use_spin_quant:
889893
if args.use_spin_quant == "cuda":
890894
from .source_transformation.spin_quant import (
@@ -971,4 +975,6 @@ def _get_source_transforms( # noqa
971975
transforms.append(replace_sdpa_with_simple_sdpa)
972976
transforms.append(replace_kv_cache_with_coreml_kv_cache)
973977

978+
print(f"Performing the following transforms: {[transform.__name__ for transform in transforms]}")
979+
974980
return transforms
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import logging
8+
9+
import torch
10+
from torchtune.modules.model_fusion._fusion import FusionEmbedding
11+
12+
13+
def _replace_fusion_embeddings_with_nn_embedding(module: torch.nn.Module) -> None:
    """
    Recursively replace TorchTune's FusionEmbedding modules with plain
    nn.Embedding, in place.

    FusionEmbedding keeps the base and fusion vocabularies in two separate
    tables, which is useful for training but bears no effect on inference.
    Replacing it with a single nn.Embedding also avoids torch ops that may
    be missing in some export backends, such as masked_select and
    masked_scatter.

    Args:
        module: Root module to transform; modified in place.
    """
    for name, child in module.named_children():
        if isinstance(child, FusionEmbedding):
            num_embeddings = (
                child.embedding.num_embeddings
                + child.fusion_embedding.num_embeddings
            )
            replacement = torch.nn.Embedding(num_embeddings, child.dim)
            # Preserve the trained weights: FusionEmbedding looks up fusion
            # tokens at indices after the base vocabulary, so concatenating
            # the two tables (base first) matches its lookup semantics.
            # Without this copy, the new embedding would hold random init
            # until (if ever) a checkpoint is re-loaded on top of it.
            with torch.no_grad():
                replacement.weight.copy_(
                    torch.cat(
                        [child.embedding.weight, child.fusion_embedding.weight],
                        dim=0,
                    )
                )
            setattr(module, name, replacement)
        else:
            # Not a FusionEmbedding: descend into its children.
            _replace_fusion_embeddings_with_nn_embedding(child)
34+
35+
def replace_fusion_embeddings_with_nn_embedding(
    module: torch.nn.Module,
) -> torch.nn.Module:
    """Replace every FusionEmbedding inside ``module`` with nn.Embedding.

    The transform is applied in place; the (same) module is returned so the
    function composes with other source transforms.
    """
    logging.info("Replacing fusion embeddings with nn.embeddings.")
    _replace_fusion_embeddings_with_nn_embedding(module)
    return module
39+

0 commit comments

Comments
 (0)