
Commit c8e8b75

Author: Guang Yang
Support MaskedLM from HuggingFace

1 parent e93ad5f

File tree

2 files changed: +157 −32 lines changed


extension/export_util/export_hf_model.py
Lines changed: 143 additions & 32 deletions
@@ -10,45 +10,80 @@
 import torch
 import torch.export._trace
 from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
-from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig, to_edge
+from executorch.exir import (
+    EdgeCompileConfig,
+    ExecutorchBackendConfig,
+    to_edge,
+    to_edge_transform_and_lower,
+)
 from torch.nn.attention import SDPBackend
-from transformers import AutoModelForCausalLM
+from transformers import (
+    AutoConfig,
+    AutoImageProcessor,
+    AutoModelForCausalLM,
+    AutoModelForMaskedLM,
+    AutoModelForSemanticSegmentation,
+    AutoTokenizer,
+)
 from transformers.generation.configuration_utils import GenerationConfig
 from transformers.integrations.executorch import convert_and_export_with_cache
 from transformers.modeling_utils import PreTrainedModel
 
+from .task_registry import register_task, task_registry
 
-def main() -> None:
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "-hfm",
-        "--hf_model_repo",
-        required=True,
-        default=None,
-        help="a valid huggingface model repo name",
+
+@register_task("masked_lm")
+def export_masked_lm(args):
+    device = "cpu"
+    max_length = 64
+    attn_implementation = "sdpa"
+
+    config = AutoConfig.from_pretrained(args.hf_model_repo)
+    kwargs = {}
+    if hasattr(config, "use_cache"):
+        kwargs["use_cache"] = True
+
+    print(f"DEBUG: attn_implementation: {attn_implementation}")
+    tokenizer = AutoTokenizer.from_pretrained(args.hf_model_repo)
+    mask_token = tokenizer.mask_token
+    print(f"Mask token: {mask_token}")
+    inputs = tokenizer(
+        f"The goal of life is {mask_token}.",
+        return_tensors="pt",
+        padding="max_length",
+        max_length=max_length,
     )
-    parser.add_argument(
-        "-d",
-        "--dtype",
-        type=str,
-        choices=["float32", "float16", "bfloat16"],
-        default="float32",
-        help="specify the dtype for loading the model",
+
+    model = AutoModelForMaskedLM.from_pretrained(
+        args.hf_model_repo,
+        device_map=device,
+        attn_implementation=attn_implementation,
+        **kwargs,
     )
-    parser.add_argument(
-        "-o",
-        "--output_name",
-        required=False,
-        default=None,
-        help="output name of the exported model",
+    print(f"{model.config}")
+    print(f"{model.generation_config}")
+
+    # pre-autograd export. eventually this will become torch.export
+    exported_program = torch.export.export_for_training(
+        model,
+        args=(inputs["input_ids"],),
+        kwargs={"attention_mask": inputs["attention_mask"]},
+        strict=True,
     )
 
-    args = parser.parse_args()
+    return model, to_edge_transform_and_lower(
+        exported_program,
+        partitioner=[XnnpackPartitioner()],
+        compile_config=EdgeCompileConfig(
+            _skip_dim_order=True,
+        ),
+    ).to_executorch(config=ExecutorchBackendConfig(extract_delegate_segments=False))
 
-    # Configs to HF model
+
+@register_task("causal_lm")
+def export_causal_lm(args):
     device = "cpu"
-    # TODO: remove getattr once https://github.com/huggingface/transformers/pull/33741 is merged
-    dtype = getattr(torch, args.dtype)
+    dtype = args.dtype
     batch_size = 1
     max_length = 123
     cache_implementation = "static"
@@ -106,11 +141,87 @@ def _get_constant_methods(model: PreTrainedModel):
         .to_backend(XnnpackPartitioner())
         .to_executorch(ExecutorchBackendConfig(extract_delegate_segments=True))
     )
-    out_name = args.output_name if args.output_name else model.config.model_type
-    filename = os.path.join("./", f"{out_name}.pte")
-    with open(filename, "wb") as f:
-        prog.write_to_file(f)
-    print(f"Saved exported program to {filename}")
+
+    return model, prog
+
+
+@register_task("semantic_segmentation")
+def export_semantic_segmentation(args):
+    import requests
+    from PIL import Image
+
+    device = "cpu"
+    model = AutoModelForSemanticSegmentation.from_pretrained(
+        args.hf_model_repo,
+        device_map=device,
+    )
+    image_processor = AutoImageProcessor.from_pretrained(
+        args.hf_model_repo,
+        device_map=device,
+    )
+    image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/coco_sample.png"
+    image = Image.open(requests.get(image_url, stream=True).raw)
+    inputs = image_processor(images=image, return_tensors="pt")
+
+    exported_program = torch.export.export_for_training(
+        model,
+        args=(inputs["pixel_values"],),
+        strict=True,
+    )
+
+    return model, to_edge_transform_and_lower(
+        exported_program,
+        partitioner=[XnnpackPartitioner()],
+        compile_config=EdgeCompileConfig(
+            _skip_dim_order=True,
+        ),
+    ).to_executorch(config=ExecutorchBackendConfig(extract_delegate_segments=False))
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-hfm",
+        "--hf_model_repo",
+        required=True,
+        default=None,
+        help="a valid huggingface model repo name",
+    )
+    parser.add_argument(
+        "-d",
+        "--dtype",
+        type=str,
+        choices=["float32", "float16", "bfloat16"],
+        default="float32",
+        help="specify the dtype for loading the model",
+    )
+    parser.add_argument(
+        "-o",
+        "--output_name",
+        required=False,
+        default=None,
+        help="output name of the exported model",
+    )
+    parser.add_argument(
+        "-t",
+        "--task",
+        type=str,
+        choices=list(task_registry.keys()),
+        default="causal_lm",
+        help=f"type of task of the model to load from huggingface. supported tasks: {task_registry.keys()}",
+    )
+
+    args = parser.parse_args()
+    try:
+        model, prog = task_registry[args.task](args)
+    except AttributeError:
+        raise ValueError(f"Unsupported task type {args.task}")
+
+    out_name = args.output_name if args.output_name else model.config.model_type
+    filename = os.path.join("./", f"{out_name}.pte")
+    with open(filename, "wb") as f:
+        prog.write_to_file(f)
+    print(f"Saved exported program to {filename}")
 
 
 if __name__ == "__main__":
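
A minimal sketch of what the new masked_lm task computes, useful for sanity-checking an export against the eager model. It runs the HuggingFace model on the same prompt shape that export_masked_lm() traces; the repo "google-bert/bert-base-uncased" is an illustrative assumption, not part of this commit:

    import torch
    from transformers import AutoModelForMaskedLM, AutoTokenizer

    repo = "google-bert/bert-base-uncased"  # assumed example; any MaskedLM repo works
    tokenizer = AutoTokenizer.from_pretrained(repo)
    model = AutoModelForMaskedLM.from_pretrained(repo)

    # Same prompt and padding as export_masked_lm() above.
    inputs = tokenizer(
        f"The goal of life is {tokenizer.mask_token}.",
        return_tensors="pt",
        padding="max_length",
        max_length=64,
    )
    with torch.no_grad():
        logits = model(**inputs).logits  # [batch, seq_len, vocab_size]

    # Decode the top prediction at the [MASK] position.
    mask_pos = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero()[0, 1]
    print(tokenizer.decode(logits[0, mask_pos].argmax(-1)))

The committed entry point would then be invoked along the lines of python extension/export_util/export_hf_model.py -hfm <repo> -t masked_lm -o <name>, writing <name>.pte (invocation form assumed from the argparse flags above, not shown in this commit).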
extension/export_util/task_registry.py
Lines changed: 14 additions & 0 deletions

@@ -0,0 +1,14 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+task_registry = {}
+
+def register_task(task_name):
+    def decorator(func):
+        task_registry[task_name] = func
+        return func
+    return decorator
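
The registry is a plain module-level dict filled in at import time by the decorator. A small self-contained sketch of the same pattern (names mirror the new module but are redefined here for illustration, not imported from it):

    task_registry = {}

    def register_task(task_name):
        def decorator(func):
            task_registry[task_name] = func
            return func
        return decorator

    @register_task("masked_lm")
    def export_masked_lm(args):
        return f"would export {args} as masked_lm"

    # Dispatch by name, as main() does with task_registry[args.task](args):
    print(task_registry["masked_lm"]("some-model-repo"))

Because the decorator returns func unchanged, each export_* function stays directly callable while also being discoverable through task_registry for the --task CLI choices.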
