Skip to content

Commit 25c150c

Browse files
author
Guang Yang
committed
Support MaskedLM from HuggingFace
1 parent e93ad5f commit 25c150c

File tree

2 files changed

+197
-36
lines changed

2 files changed

+197
-36
lines changed

extension/export_util/export_hf_model.py

Lines changed: 181 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -10,45 +10,33 @@
1010
import torch
1111
import torch.export._trace
1212
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
13-
from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig, to_edge
13+
from executorch.exir import (
14+
EdgeCompileConfig,
15+
ExecutorchBackendConfig,
16+
to_edge,
17+
to_edge_transform_and_lower,
18+
)
1419
from torch.nn.attention import SDPBackend
15-
from transformers import AutoModelForCausalLM
20+
from transformers import (
21+
AutoConfig,
22+
AutoImageProcessor,
23+
AutoModelForCausalLM,
24+
AutoModelForDepthEstimation,
25+
AutoModelForMaskedLM,
26+
AutoModelForSemanticSegmentation,
27+
AutoTokenizer,
28+
)
1629
from transformers.generation.configuration_utils import GenerationConfig
1730
from transformers.integrations.executorch import convert_and_export_with_cache
1831
from transformers.modeling_utils import PreTrainedModel
1932

33+
from .task_registry import register_task, task_registry
2034

21-
def main() -> None:
22-
parser = argparse.ArgumentParser()
23-
parser.add_argument(
24-
"-hfm",
25-
"--hf_model_repo",
26-
required=True,
27-
default=None,
28-
help="a valid huggingface model repo name",
29-
)
30-
parser.add_argument(
31-
"-d",
32-
"--dtype",
33-
type=str,
34-
choices=["float32", "float16", "bfloat16"],
35-
default="float32",
36-
help="specify the dtype for loading the model",
37-
)
38-
parser.add_argument(
39-
"-o",
40-
"--output_name",
41-
required=False,
42-
default=None,
43-
help="output name of the exported model",
44-
)
4535

46-
args = parser.parse_args()
47-
48-
# Configs to HF model
36+
@register_task("causal_lm")
37+
def export_causal_lm(args):
4938
device = "cpu"
50-
# TODO: remove getattr once https://github.com/huggingface/transformers/pull/33741 is merged
51-
dtype = getattr(torch, args.dtype)
39+
dtype = args.dtype
5240
batch_size = 1
5341
max_length = 123
5442
cache_implementation = "static"
@@ -106,11 +94,168 @@ def _get_constant_methods(model: PreTrainedModel):
10694
.to_backend(XnnpackPartitioner())
10795
.to_executorch(ExecutorchBackendConfig(extract_delegate_segments=True))
10896
)
109-
out_name = args.output_name if args.output_name else model.config.model_type
110-
filename = os.path.join("./", f"{out_name}.pte")
111-
with open(filename, "wb") as f:
112-
prog.write_to_file(f)
113-
print(f"Saved exported program to {filename}")
97+
98+
return model, prog
99+
100+
101+
@register_task("masked_lm")
102+
def export_masked_lm(args):
103+
device = "cpu"
104+
max_length = 64
105+
attn_implementation = "sdpa"
106+
107+
config = AutoConfig.from_pretrained(args.hf_model_repo)
108+
kwargs = {}
109+
if hasattr(config, "use_cache"):
110+
kwargs["use_cache"] = True
111+
112+
print(f"DEBUG: attn_implementation: {attn_implementation}")
113+
tokenizer = AutoTokenizer.from_pretrained(args.hf_model_repo)
114+
mask_token = tokenizer.mask_token
115+
print(f"Mask token: {mask_token}")
116+
inputs = tokenizer(
117+
f"The goal of life is {mask_token}.",
118+
return_tensors="pt",
119+
padding="max_length",
120+
max_length=max_length,
121+
)
122+
123+
model = AutoModelForMaskedLM.from_pretrained(
124+
args.hf_model_repo,
125+
device_map=device,
126+
attn_implementation=attn_implementation,
127+
**kwargs,
128+
)
129+
print(f"{model.config}")
130+
print(f"{model.generation_config}")
131+
132+
# pre-autograd export. eventually this will become torch.export
133+
exported_program = torch.export.export_for_training(
134+
model,
135+
args=(inputs["input_ids"],),
136+
kwargs={"attention_mask": inputs["attention_mask"]},
137+
strict=True,
138+
)
139+
140+
return model, to_edge_transform_and_lower(
141+
exported_program,
142+
partitioner=[XnnpackPartitioner()],
143+
compile_config=EdgeCompileConfig(
144+
_skip_dim_order=True,
145+
),
146+
).to_executorch(config=ExecutorchBackendConfig(extract_delegate_segments=False))
147+
148+
149+
@register_task("semantic_segmentation")
150+
def export_semantic_segmentation(args):
151+
import requests
152+
from PIL import Image
153+
154+
device = "cpu"
155+
model = AutoModelForSemanticSegmentation.from_pretrained(
156+
args.hf_model_repo,
157+
device_map=device,
158+
)
159+
image_processor = AutoImageProcessor.from_pretrained(
160+
args.hf_model_repo,
161+
device_map=device,
162+
)
163+
image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/coco_sample.png"
164+
image = Image.open(requests.get(image_url, stream=True).raw)
165+
inputs = image_processor(images=image, return_tensors="pt")
166+
167+
exported_program = torch.export.export_for_training(
168+
model,
169+
args=(inputs["pixel_values"],),
170+
strict=True,
171+
)
172+
173+
return model, to_edge_transform_and_lower(
174+
exported_program,
175+
partitioner=[XnnpackPartitioner()],
176+
compile_config=EdgeCompileConfig(
177+
_skip_dim_order=True,
178+
),
179+
).to_executorch(config=ExecutorchBackendConfig(extract_delegate_segments=False))
180+
181+
182+
@register_task("depth_estimation")
183+
def export_depth_estimation(args):
184+
import requests
185+
from PIL import Image
186+
187+
device = "cpu"
188+
model = AutoModelForDepthEstimation.from_pretrained(
189+
args.hf_model_repo,
190+
device_map=device,
191+
)
192+
image_processor = AutoImageProcessor.from_pretrained(
193+
args.hf_model_repo,
194+
device_map=device,
195+
)
196+
image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/coco_sample.png"
197+
image = Image.open(requests.get(image_url, stream=True).raw)
198+
inputs = image_processor(images=image, return_tensors="pt")
199+
200+
exported_program = torch.export.export_for_training(
201+
model,
202+
args=(inputs["pixel_values"],),
203+
strict=True,
204+
)
205+
206+
return model, to_edge_transform_and_lower(
207+
exported_program,
208+
partitioner=[XnnpackPartitioner()],
209+
compile_config=EdgeCompileConfig(
210+
_skip_dim_order=True,
211+
),
212+
).to_executorch(config=ExecutorchBackendConfig(extract_delegate_segments=False))
213+
214+
215+
def main() -> None:
216+
parser = argparse.ArgumentParser()
217+
parser.add_argument(
218+
"-hfm",
219+
"--hf_model_repo",
220+
required=True,
221+
default=None,
222+
help="a valid huggingface model repo name",
223+
)
224+
parser.add_argument(
225+
"-d",
226+
"--dtype",
227+
type=str,
228+
choices=["float32", "float16", "bfloat16"],
229+
default="float32",
230+
help="specify the dtype for loading the model",
231+
)
232+
parser.add_argument(
233+
"-o",
234+
"--output_name",
235+
required=False,
236+
default=None,
237+
help="output name of the exported model",
238+
)
239+
parser.add_argument(
240+
"-t",
241+
"--task",
242+
type=str,
243+
choices=list(task_registry.keys()),
244+
default="causal_lm",
245+
help=f"type of task of the model to load from huggingface. supported tasks: {task_registry.keys()}",
246+
)
247+
248+
args = parser.parse_args()
249+
try:
250+
model, prog = task_registry[args.task](args)
251+
except AttributeError:
252+
raise ValueError(f"Unsupported task type {args.task}")
253+
254+
out_name = args.output_name if args.output_name else model.config.model_type
255+
filename = os.path.join("./", f"{out_name}.pte")
256+
with open(filename, "wb") as f:
257+
prog.write_to_file(f)
258+
print(f"Saved exported program to {filename}")
114259

115260

116261
if __name__ == "__main__":
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
8+
task_registry = {}
9+
10+
11+
def register_task(task_name):
12+
def decorator(func):
13+
task_registry[task_name] = func
14+
return func
15+
16+
return decorator

0 commit comments

Comments
 (0)