
Commit e27ccdc

Author: Guang Yang (committed)
Support MaskedLM from HuggingFace
1 parent: e93ad5f
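
With this change, the exporter can target a HuggingFace masked LM as well as a causal LM, selected via the new -lm flag (default: causal_lm). A hypothetical invocation, where the model repo name is illustrative rather than part of this commit:

    python extension/export_util/export_hf_model.py -hfm bert-base-uncased -lm masked_lm -o bert_mlm

The resulting program is written to ./bert_mlm.pte (or to the model type's name when -o is omitted).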

File tree: 1 file changed, +106 −30 lines


extension/export_util/export_hf_model.py

Lines changed: 106 additions & 30 deletions
@@ -10,42 +10,73 @@
 import torch
 import torch.export._trace
 from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
-from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig, to_edge
+from executorch.exir import (
+    EdgeCompileConfig,
+    ExecutorchBackendConfig,
+    to_edge,
+    to_edge_transform_and_lower,
+)
 from torch.nn.attention import SDPBackend
-from transformers import AutoModelForCausalLM
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoModelForMaskedLM,
+    AutoTokenizer,
+)
 from transformers.generation.configuration_utils import GenerationConfig
 from transformers.integrations.executorch import convert_and_export_with_cache
 from transformers.modeling_utils import PreTrainedModel
 
 
-def main() -> None:
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "-hfm",
-        "--hf_model_repo",
-        required=True,
-        default=None,
-        help="a valid huggingface model repo name",
+def _export_masked_lm(args):
+
+    device = "cpu"
+    max_length = 64
+    attn_implementation = "sdpa"
+
+    config = AutoConfig.from_pretrained(args.hf_model_repo)
+    kwargs = {}
+    if hasattr(config, "use_cache"):
+        kwargs["use_cache"] = True
+
+    print(f"DEBUG: attn_implementation: {attn_implementation}")
+    tokenizer = AutoTokenizer.from_pretrained(args.hf_model_repo)
+    mask_token = tokenizer.mask_token
+    print(f"Mask token: {mask_token}")
+    inputs = tokenizer(
+        f"The goal of life is {mask_token}.",
+        return_tensors="pt",
+        padding="max_length",
+        max_length=max_length,
     )
-    parser.add_argument(
-        "-d",
-        "--dtype",
-        type=str,
-        choices=["float32", "float16", "bfloat16"],
-        default="float32",
-        help="specify the dtype for loading the model",
+
+    model = AutoModelForMaskedLM.from_pretrained(
+        args.hf_model_repo,
+        device_map=device,
+        attn_implementation=attn_implementation,
+        **kwargs,
     )
-    parser.add_argument(
-        "-o",
-        "--output_name",
-        required=False,
-        default=None,
-        help="output name of the exported model",
+    print(f"{model.config}")
+    print(f"{model.generation_config}")
+
+    # pre-autograd export. eventually this will become torch.export
+    exported_program = torch.export.export_for_training(
+        model,
+        args=(inputs["input_ids"],),
+        kwargs={"attention_mask": inputs["attention_mask"]},
+        strict=True,
     )
 
-    args = parser.parse_args()
+    return model, to_edge_transform_and_lower(
+        exported_program,
+        partitioner=[XnnpackPartitioner()],
+        compile_config=EdgeCompileConfig(
+            _skip_dim_order=True,
+        ),
+    ).to_executorch(config=ExecutorchBackendConfig(extract_delegate_segments=False))
 
-    # Configs to HF model
+
+def _export_causal_lm(args):
     device = "cpu"
     # TODO: remove getattr once https://github.com/huggingface/transformers/pull/33741 is merged
     dtype = getattr(torch, args.dtype)
@@ -106,11 +137,56 @@ def _get_constant_methods(model: PreTrainedModel):
         .to_backend(XnnpackPartitioner())
         .to_executorch(ExecutorchBackendConfig(extract_delegate_segments=True))
     )
-    out_name = args.output_name if args.output_name else model.config.model_type
-    filename = os.path.join("./", f"{out_name}.pte")
-    with open(filename, "wb") as f:
-        prog.write_to_file(f)
-    print(f"Saved exported program to {filename}")
+
+    return model, prog
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-hfm",
+        "--hf_model_repo",
+        required=True,
+        default=None,
+        help="a valid huggingface model repo name",
+    )
+    parser.add_argument(
+        "-d",
+        "--dtype",
+        type=str,
+        choices=["float32", "float16", "bfloat16"],
+        default="float32",
+        help="specify the dtype for loading the model",
+    )
+    parser.add_argument(
+        "-o",
+        "--output_name",
+        required=False,
+        default=None,
+        help="output name of the exported model",
+    )
+    parser.add_argument(
+        "-lm",
+        type=str,
+        choices=["masked_lm", "causal_lm"],
+        default="causal_lm",
+        help="type of lm to load from huggingface",
+    )
+
+    args = parser.parse_args()
+
+    if args.lm == "masked_lm":
+        model, prog = _export_masked_lm(args)
+    elif args.lm == "causal_lm":
+        model, prog = _export_causal_lm(args)
+    else:
+        raise ValueError(f"Unsupported LM type {args.lm}")
+
+    out_name = args.output_name if args.output_name else model.config.model_type
+    filename = os.path.join("./", f"{out_name}.pte")
+    with open(filename, "wb") as f:
+        prog.write_to_file(f)
+    print(f"Saved exported program to {filename}")
 
 
 if __name__ == "__main__":
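
For a quick eager-mode sanity check of the masked-LM path before export, the model can be queried directly with standard transformers APIs. This is a minimal sketch, not part of the commit; the checkpoint name and the top-5 cutoff are illustrative assumptions:

    import torch
    from transformers import AutoModelForMaskedLM, AutoTokenizer

    repo = "bert-base-uncased"  # illustrative masked-LM checkpoint (assumption)
    tokenizer = AutoTokenizer.from_pretrained(repo)
    model = AutoModelForMaskedLM.from_pretrained(repo)

    # Same style of prompt the exporter builds, using the model's mask token.
    inputs = tokenizer(f"The goal of life is {tokenizer.mask_token}.", return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits

    # Locate the mask position and report the top-5 token predictions there.
    mask_pos = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero()[0, 1]
    top5 = torch.topk(logits[0, mask_pos], k=5).indices
    print(tokenizer.convert_ids_to_tokens(top5.tolist()))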
