|
@@ -10,42 +10,73 @@
 import torch
 import torch.export._trace
 from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner
-from executorch.exir import EdgeCompileConfig, ExecutorchBackendConfig, to_edge
+from executorch.exir import (
+    EdgeCompileConfig,
+    ExecutorchBackendConfig,
+    to_edge,
+    to_edge_transform_and_lower,
+)
 from torch.nn.attention import SDPBackend
-from transformers import AutoModelForCausalLM
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    AutoModelForMaskedLM,
+    AutoTokenizer,
+)
 from transformers.generation.configuration_utils import GenerationConfig
 from transformers.integrations.executorch import convert_and_export_with_cache
 from transformers.modeling_utils import PreTrainedModel

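The import hunk adds to_edge_transform_and_lower alongside the existing to_edge; the two export paths below use one each. The masked-LM path lowers and delegates in a single call, while the causal-LM path keeps the older two-step form. A minimal sketch of how the two forms relate, assuming an already-exported program ep (illustrative name, not from this change):

    # One-step form: convert to Edge dialect and delegate to XNNPACK together.
    prog = to_edge_transform_and_lower(
        ep, partitioner=[XnnpackPartitioner()]
    ).to_executorch()

    # Two-step form: convert first, then delegate explicitly.
    prog = to_edge(ep).to_backend(XnnpackPartitioner()).to_executorch()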
|
-def main() -> None:
-    parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "-hfm",
-        "--hf_model_repo",
-        required=True,
-        default=None,
-        help="a valid huggingface model repo name",
+def _export_masked_lm(args):
+
+    device = "cpu"
+    max_length = 64
+    attn_implementation = "sdpa"
+
+    config = AutoConfig.from_pretrained(args.hf_model_repo)
+    kwargs = {}
+    if hasattr(config, "use_cache"):
+        kwargs["use_cache"] = True
+
+    print(f"DEBUG: attn_implementation: {attn_implementation}")
+    tokenizer = AutoTokenizer.from_pretrained(args.hf_model_repo)
+    mask_token = tokenizer.mask_token
+    print(f"Mask token: {mask_token}")
+    inputs = tokenizer(
+        f"The goal of life is {mask_token}.",
+        return_tensors="pt",
+        padding="max_length",
+        max_length=max_length,
     )
-    parser.add_argument(
-        "-d",
-        "--dtype",
-        type=str,
-        choices=["float32", "float16", "bfloat16"],
-        default="float32",
-        help="specify the dtype for loading the model",
+
+    model = AutoModelForMaskedLM.from_pretrained(
+        args.hf_model_repo,
+        device_map=device,
+        attn_implementation=attn_implementation,
+        **kwargs,
     )
-    parser.add_argument(
-        "-o",
-        "--output_name",
-        required=False,
-        default=None,
-        help="output name of the exported model",
+    print(f"{model.config}")
+    print(f"{model.generation_config}")
+
+    # pre-autograd export. eventually this will become torch.export
+    exported_program = torch.export.export_for_training(
+        model,
+        args=(inputs["input_ids"],),
+        kwargs={"attention_mask": inputs["attention_mask"]},
+        strict=True,
     )

-    args = parser.parse_args()
+    return model, to_edge_transform_and_lower(
+        exported_program,
+        partitioner=[XnnpackPartitioner()],
+        compile_config=EdgeCompileConfig(
+            _skip_dim_order=True,
+        ),
+    ).to_executorch(config=ExecutorchBackendConfig(extract_delegate_segments=False))

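_export_masked_lm returns both the eager model and the lowered program, so main can derive the output file name from model.config. Before lowering, a quick eager sanity check can confirm the mask inputs are wired correctly. This is an illustrative sketch, not part of the change, reusing tokenizer, inputs, and model from the function above:

    # Run the eager model and decode the top prediction at the mask position.
    with torch.no_grad():
        logits = model(**inputs).logits
    mask_pos = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
    print(tokenizer.decode(logits[0, mask_pos].argmax(-1)))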
|
-    # Configs to HF model
+
+def _export_causal_lm(args):
     device = "cpu"
     # TODO: remove getattr once https://github.com/huggingface/transformers/pull/33741 is merged
     dtype = getattr(torch, args.dtype)
@@ -106,11 +137,56 @@ def _get_constant_methods(model: PreTrainedModel):
         .to_backend(XnnpackPartitioner())
         .to_executorch(ExecutorchBackendConfig(extract_delegate_segments=True))
     )
-    out_name = args.output_name if args.output_name else model.config.model_type
-    filename = os.path.join("./", f"{out_name}.pte")
-    with open(filename, "wb") as f:
-        prog.write_to_file(f)
-    print(f"Saved exported program to {filename}")
+
+    return model, prog
+
+
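The body of _export_causal_lm between the two hunks is elided by the diff; judging from the unchanged convert_and_export_with_cache import, it goes through the transformers ExecuTorch integration. A hedged sketch of that helper's typical use, not the PR's exact code:

    # convert_and_export_with_cache wraps the HF model with a static KV cache
    # and returns an ExportedProgram that to_edge can consume.
    with torch.nn.attention.sdpa_kernel([SDPBackend.MATH]), torch.no_grad():
        exported = convert_and_export_with_cache(model)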
+def main() -> None:
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "-hfm",
+        "--hf_model_repo",
+        required=True,
+        default=None,
+        help="a valid huggingface model repo name",
+    )
+    parser.add_argument(
+        "-d",
+        "--dtype",
+        type=str,
+        choices=["float32", "float16", "bfloat16"],
+        default="float32",
+        help="specify the dtype for loading the model",
+    )
+    parser.add_argument(
+        "-o",
+        "--output_name",
+        required=False,
+        default=None,
+        help="output name of the exported model",
+    )
+    parser.add_argument(
+        "-lm",
+        type=str,
+        choices=["masked_lm", "causal_lm"],
+        default="causal_lm",
+        help="type of lm to load from huggingface",
+    )
+
+    args = parser.parse_args()
+
+    if args.lm == "masked_lm":
+        model, prog = _export_masked_lm(args)
+    elif args.lm == "causal_lm":
+        model, prog = _export_causal_lm(args)
+    else:
+        raise ValueError(f"Unsupported LM type {args.lm}")
+
+    out_name = args.output_name if args.output_name else model.config.model_type
+    filename = os.path.join("./", f"{out_name}.pte")
+    with open(filename, "wb") as f:
+        prog.write_to_file(f)
+    print(f"Saved exported program to {filename}")


 if __name__ == "__main__":
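Once the script writes the .pte, it can be smoke-tested with the ExecuTorch pybindings. A minimal sketch, assuming a pybindings build that includes the XNNPACK backend and reusing the tokenized inputs from _export_masked_lm; the file name bert.pte is illustrative:

    from executorch.extension.pybindings.portable_lib import _load_for_executorch

    # Load the lowered program and run one forward pass.
    module = _load_for_executorch("bert.pte")
    outputs = module.forward((inputs["input_ids"], inputs["attention_mask"]))
    print(outputs[0].shape)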
|