Commit 1b48e83

Guang Yang committed
Support MaskedLM from HuggingFace
1 parent e93ad5f · commit 1b48e83

File tree

1 file changed: +95 -30 lines changed

extension/export_util/export_hf_model.py

Lines changed: 95 additions & 30 deletions
@@ -10,42 +10,62 @@
 import torch
 import torch.export._trace
 from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
-from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig, to_edge
+from executorch.exir import (
+    EdgeCompileConfig,
+    ExecutorchBackendConfig,
+    to_edge,
+    to_edge_transform_and_lower,
+)
 from torch.nn.attention import SDPBackend
-from transformers import AutoModelForCausalLM
+from transformers import AutoModelForCausalLM, AutoModelForMaskedLM, AutoTokenizer
 from transformers.generation.configuration_utils import GenerationConfig
 from transformers.integrations.executorch import convert_and_export_with_cache
 from transformers.modeling_utils import PreTrainedModel
 
 
-def main() -> None:
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "-hfm",
-        "--hf_model_repo",
-        required=True,
-        default=None,
-        help="a valid huggingface model repo name",
+def _export_masked_lm(args):
+
+    device = "cpu"
+    attn_implementation = "sdpa"
+    max_length = 64
+
+    tokenizer = AutoTokenizer.from_pretrained(args.hf_model_repo)
+    mask_token = tokenizer.mask_token
+    print(f"Mask token: {mask_token}")
+    inputs = tokenizer(
+        f"The goal of life is {mask_token}.",
+        return_tensors="pt",
+        padding="max_length",
+        max_length=max_length,
     )
-    parser.add_argument(
-        "-d",
-        "--dtype",
-        type=str,
-        choices=["float32", "float16", "bfloat16"],
-        default="float32",
-        help="specify the dtype for loading the model",
+
+    model = AutoModelForMaskedLM.from_pretrained(
+        args.hf_model_repo,
+        device_map=device,
+        attn_implementation=attn_implementation,
+        use_cache=True,
     )
-    parser.add_argument(
-        "-o",
-        "--output_name",
-        required=False,
-        default=None,
-        help="output name of the exported model",
+    print(f"{model.config}")
+    print(f"{model.generation_config}")
+
+    # pre-autograd export. eventually this will become torch.export
+    exported_program = torch.export.export_for_training(
+        model,
+        args=(inputs["input_ids"],),
+        kwargs={"attention_mask": inputs["attention_mask"]},
+        strict=True,
     )
 
-    args = parser.parse_args()
+    return model, to_edge_transform_and_lower(
+        exported_program,
+        partitioner=[XnnpackPartitioner()],
+        compile_config=EdgeCompileConfig(
+            _skip_dim_order=True,
+        ),
+    ).to_executorch(config=ExecutorchBackendConfig(extract_delegate_segments=False))
+
 
-    # Configs to HF model
+def _export_causal_lm(args):
     device = "cpu"
     # TODO: remove getattr once https://github.com/huggingface/transformers/pull/33741 is merged
     dtype = getattr(torch, args.dtype)
@@ -106,11 +126,56 @@ def _get_constant_methods(model: PreTrainedModel):
         .to_backend(XnnpackPartitioner())
         .to_executorch(ExecutorchBackendConfig(extract_delegate_segments=True))
     )
-    out_name = args.output_name if args.output_name else model.config.model_type
-    filename = os.path.join("./", f"{out_name}.pte")
-    with open(filename, "wb") as f:
-        prog.write_to_file(f)
-    print(f"Saved exported program to {filename}")
+
+    return model, prog
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-hfm",
+        "--hf_model_repo",
+        required=True,
+        default=None,
+        help="a valid huggingface model repo name",
+    )
+    parser.add_argument(
+        "-d",
+        "--dtype",
+        type=str,
+        choices=["float32", "float16", "bfloat16"],
+        default="float32",
+        help="specify the dtype for loading the model",
+    )
+    parser.add_argument(
+        "-o",
+        "--output_name",
+        required=False,
+        default=None,
+        help="output name of the exported model",
+    )
+    parser.add_argument(
+        "-lm",
+        required=True,
+        type=str,
+        choices=["masked_lm", "causal_lm"],
+        help="type of lm to load from huggingface",
+    )
+
+    args = parser.parse_args()
+
+    if args.lm == "masked_lm":
+        model, prog = _export_masked_lm(args)
+    elif args.lm == "causal_lm":
+        model, prog = _export_causal_lm(args)
+    else:
+        raise ValueError(f"Unsupported LM type {args.lm}")
+
+    out_name = args.output_name if args.output_name else model.config.model_type
+    filename = os.path.join("./", f"{out_name}.pte")
+    with open(filename, "wb") as f:
+        prog.write_to_file(f)
+    print(f"Saved exported program to {filename}")
 
 
 if __name__ == "__main__":
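
With this change, exporting a masked LM comes down to passing `-lm masked_lm`, e.g. `python extension/export_util/export_hf_model.py -hfm bert-base-uncased -lm masked_lm` (the repo name here is an assumed example; any `AutoModelForMaskedLM`-compatible checkpoint works). Before exporting, the prompt hard-coded in `_export_masked_lm` can be sanity-checked against the eager HuggingFace model; a minimal sketch, assuming the same example checkpoint:

# Hypothetical sanity check, not part of this commit: run the same prompt
# through the eager HuggingFace model to see what the mask resolves to.
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

repo = "bert-base-uncased"  # assumed example; use whatever repo is passed to -hfm
tokenizer = AutoTokenizer.from_pretrained(repo)
model = AutoModelForMaskedLM.from_pretrained(repo)

inputs = tokenizer(f"The goal of life is {tokenizer.mask_token}.", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

# Find the mask position and decode the highest-scoring token for it.
mask_pos = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
print(tokenizer.decode(logits[0, mask_pos].argmax(dim=-1)))

If `-o` is omitted, the exported program is written to `./<model_type>.pte` (e.g. `bert.pte` for a BERT checkpoint), per the `out_name` fallback in `main()`.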

0 commit comments