Skip to content

Commit 36c2dd1

Browse files
neuropilot-captainneuropilot-captain
andauthored
Support bert distilbert (#14138)
### Summary Add export scripts to support BERT and DistilBERT. --------- Co-authored-by: neuropilot-captain <[email protected]>
1 parent 62a60d6 commit 36c2dd1

File tree

5 files changed

+229
-1
lines changed

5 files changed

+229
-1
lines changed

examples/mediatek/README.md

100644100755
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ source shell_scripts/export_llama.sh <model_name> <num_chunks> <prompt_num_token
7171
bash shell_scripts/export_oss.sh <model_name>
7272
```
7373
- Argument Options:
74-
- `model_name`: deeplabv3/edsr/inceptionv3/inceptionv4/mobilenetv2/mobilenetv3/resnet18/resnet50
74+
- `model_name`: deeplabv3/edsr/inceptionv3/inceptionv4/mobilenetv2/mobilenetv3/resnet18/resnet50/dcgan/wav2letter/vit_b_16/mobilebert/emformer_rnnt/bert/distilbert
7575

7676
# Runtime
7777
## Environment Setup

examples/mediatek/aot_utils/oss_utils/utils.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import os
8+
import random
89
from typing import Optional
910

1011
import torch
12+
import transformers
1113
from executorch import exir
1214
from executorch.backends.mediatek import (
1315
NeuropilotPartitioner,
@@ -42,6 +44,7 @@ def build_executorch_binary(
4244
quantized_model = convert_pt2e(annotated_model, fold_quantize=False)
4345
aten_dialect = torch.export.export(quantized_model, inputs, strict=True)
4446
else:
47+
print("Using float model...")
4548
aten_dialect = torch.export.export(model, inputs, strict=True)
4649

4750
from executorch.exir.program._program import to_edge_transform_and_lower
@@ -71,3 +74,58 @@ def make_output_dir(path: str):
7174
os.remove(os.path.join(path, f))
7275
os.removedirs(path)
7376
os.makedirs(path)
77+
78+
79+
def get_masked_language_model_dataset(dataset_path, tokenizer, data_size, shuffle=True):
    """Build calibration data for a masked-language model.

    Reads newline-separated sentences from ``dataset_path``, applies random
    token masking via ``transformers.DataCollatorForLanguageModeling`` (its
    default MLM masking probability), and collects up to ``data_size``
    samples that contain at least one masked position.

    Args:
        dataset_path: Path to a plain-text file, one sentence per line.
        tokenizer: HuggingFace tokenizer used to encode each sentence.
        data_size: Maximum number of (input, target) pairs to return.
        shuffle: Whether the DataLoader iterates the samples in random order.

    Returns:
        A tuple ``(inputs, targets)``: each input is an
        ``(input_ids, attention_mask)`` pair (int32 / float32 tensors, with
        the DataLoader's leading batch dim of 1); each target is the label
        tensor where ``-100`` marks non-masked positions.
    """

    def get_data_loader():
        class MaskedSentencesDataset(torch.utils.data.Dataset):
            def __init__(self, dataset_path, tokenizer, data_size) -> None:
                self.data_size = data_size
                self.dataset = self._get_val_dataset(dataset_path, data_size, tokenizer)

            def _get_val_dataset(self, dataset_path, data_size, tokenizer):
                data_collator = transformers.DataCollatorForLanguageModeling(
                    tokenizer=tokenizer
                )
                with open(dataset_path, "r") as f:
                    texts = f.read().split("\n")
                # Sample (with replacement) a pool of candidate sentences,
                # dropping empty / single-character lines.
                texts = [
                    text for text in random.choices(texts, k=2000) if len(text) > 1
                ]
                dataset = data_collator([tokenizer(text) for text in texts])
                return dataset

            def __getitem__(self, idx):
                return (
                    self.dataset["input_ids"][idx].to(torch.int32),
                    self.dataset["attention_mask"][idx].to(torch.float32),
                    self.dataset["labels"][idx],
                )

            def __len__(self):
                # Fix: the filtered sentence pool can contain fewer than
                # `data_size` entries; clamping prevents an IndexError in
                # __getitem__ when the DataLoader iterates the full length.
                return min(self.data_size, len(self.dataset["input_ids"]))

        dataset = MaskedSentencesDataset(dataset_path, tokenizer, data_size)
        return torch.utils.data.DataLoader(
            dataset,
            shuffle=shuffle,
        )

    # prepare input data
    inputs, targets = [], []
    data_loader = get_data_loader()
    for data in data_loader:
        if len(inputs) >= data_size:
            break
        input_ids = data[0]
        attention_mask = data[1]
        target = data[2][0]
        # Positions whose label != -100 are the masked tokens to predict.
        indice = [i for i, x in enumerate(target) if x != -100]
        # continue if no mask annotated
        if len(indice) == 0:
            continue
        inputs.append((input_ids, attention_mask))
        targets.append(target)

    return inputs, targets
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Copyright (c) MediaTek Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import argparse
8+
import os
9+
import sys
10+
11+
if os.getcwd() not in sys.path:
12+
sys.path.append(os.getcwd())
13+
14+
from aot_utils.oss_utils.utils import (
15+
build_executorch_binary,
16+
get_masked_language_model_dataset,
17+
)
18+
from transformers import AutoModelForMaskedLM, AutoTokenizer
19+
20+
21+
def main(args):
    """Export google-bert/bert-base-uncased as a MediaTek ExecuTorch binary
    and dump calibration inputs plus golden outputs for on-device inference."""
    pretrained_name = "google-bert/bert-base-uncased"
    data_size = 100

    # Make sure the working directory exists before any files are written.
    os.makedirs(args.artifact, exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(pretrained_name)
    inputs, targets = get_masked_language_model_dataset(
        args.dataset, tokenizer, data_size
    )

    # Lower the eval-mode model to a .pte with embedding/where ops kept on CPU.
    module = AutoModelForMaskedLM.from_pretrained(pretrained_name).eval()
    pte_filename = "bert_mtk"
    build_executorch_binary(
        module,
        inputs[0],
        f"{args.artifact}/{pte_filename}",
        inputs,
        skip_op_name={"aten_embedding_default", "aten_where_self"},
    )

    # Dump the calibration inputs and golden outputs for on-device runs.
    with open(f"{args.artifact}/input_list.txt", "w") as f:
        f.writelines(
            f"input_{i}_0.bin input_{i}_1.bin\n" for i in range(len(inputs))
        )
    for idx, sample in enumerate(inputs):
        for pos, tensor in enumerate(sample):
            tensor.detach().numpy().tofile(f"{args.artifact}/input_{idx}_{pos}.bin")
    for idx, golden in enumerate(targets):
        golden.detach().numpy().tofile(f"{args.artifact}/golden_{idx}_0.bin")
58+
59+
if __name__ == "__main__":
    # CLI: where to place artifacts, and which text corpus to calibrate with.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-a",
        "--artifact",
        type=str,
        default="./bert",
        help="path for storing generated artifacts and output by this example. Default ./bert",
    )
    parser.add_argument(
        "-d",
        "--dataset",
        type=str,
        default="wikisent2.txt",
        required=False,
        help=(
            "path to the validation text. "
            "e.g. --dataset wikisent2.txt "
            "for https://www.kaggle.com/datasets/mikeortman/wikipedia-sentences"
        ),
    )
    main(parser.parse_args())
Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
# Copyright (c) MediaTek Inc.
2+
# All rights reserved
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import argparse
8+
import os
9+
import sys
10+
11+
if os.getcwd() not in sys.path:
12+
sys.path.append(os.getcwd())
13+
14+
from aot_utils.oss_utils.utils import (
15+
build_executorch_binary,
16+
get_masked_language_model_dataset,
17+
)
18+
from transformers import AutoModelForMaskedLM, AutoTokenizer
19+
20+
21+
def main(args):
    """Export distilbert/distilbert-base-uncased as a MediaTek ExecuTorch
    binary and dump calibration inputs plus golden outputs for on-device
    inference."""
    pretrained_name = "distilbert/distilbert-base-uncased"
    data_size = 100

    # Make sure the working directory exists before any files are written.
    os.makedirs(args.artifact, exist_ok=True)

    tokenizer = AutoTokenizer.from_pretrained(pretrained_name)
    inputs, targets = get_masked_language_model_dataset(
        args.dataset, tokenizer, data_size
    )

    # Lower the eval-mode model to a .pte with embedding/where ops kept on CPU.
    module = AutoModelForMaskedLM.from_pretrained(pretrained_name).eval()
    pte_filename = "distilbert_mtk"
    build_executorch_binary(
        module,
        inputs[0],
        f"{args.artifact}/{pte_filename}",
        inputs,
        skip_op_name={"aten_embedding_default", "aten_where_self"},
    )

    # Dump the calibration inputs and golden outputs for on-device runs.
    with open(f"{args.artifact}/input_list.txt", "w") as f:
        f.writelines(
            f"input_{i}_0.bin input_{i}_1.bin\n" for i in range(len(inputs))
        )
    for idx, sample in enumerate(inputs):
        for pos, tensor in enumerate(sample):
            tensor.detach().numpy().tofile(f"{args.artifact}/input_{idx}_{pos}.bin")
    for idx, golden in enumerate(targets):
        golden.detach().numpy().tofile(f"{args.artifact}/golden_{idx}_0.bin")
58+
59+
if __name__ == "__main__":
    # CLI: where to place artifacts, and which text corpus to calibrate with.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-a",
        "--artifact",
        type=str,
        default="./distilbert",
        help="path for storing generated artifacts and output by this example. Default ./distilbert",
    )
    parser.add_argument(
        "-d",
        "--dataset",
        type=str,
        default="wikisent2.txt",
        required=False,
        help=(
            "path to the validation text. "
            "e.g. --dataset wikisent2.txt "
            "for https://www.kaggle.com/datasets/mikeortman/wikipedia-sentences"
        ),
    )
    main(parser.parse_args())

examples/mediatek/shell_scripts/export_oss.sh

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,4 +41,10 @@ then
4141
elif [ $model = "emformer_rnnt" ]
4242
then
4343
python3 model_export_scripts/emformer_rnnt.py
44+
elif [ $model = "bert" ]
45+
then
46+
python3 model_export_scripts/bert.py
47+
elif [ $model = "distilbert" ]
48+
then
49+
python3 model_export_scripts/distilbert.py
4450
fi

0 commit comments

Comments
 (0)