@@ -5,9 +5,11 @@
 # LICENSE file in the root directory of this source tree.
 
 import os
+import random
 from typing import Optional
 
 import torch
+import transformers
 from executorch import exir
 from executorch.backends.mediatek import (
     NeuropilotPartitioner,
@@ -42,6 +44,7 @@ def build_executorch_binary(
         quantized_model = convert_pt2e(annotated_model, fold_quantize=False)
         aten_dialect = torch.export.export(quantized_model, inputs, strict=True)
     else:
+        print("Using float model...")
         aten_dialect = torch.export.export(model, inputs, strict=True)
 
     from executorch.exir.program._program import to_edge_transform_and_lower
@@ -71,3 +74,63 @@ def make_output_dir(path: str):
             os.remove(os.path.join(path, f))
         os.removedirs(path)
     os.makedirs(path)
+
+
+def get_masked_language_model_dataset(dataset_path, tokenizer, data_size, shuffle=True):
+    """Build a masked-LM calibration set: (input_ids, attention_mask) pairs
+    plus collator-generated labels (-100 marks unmasked positions)."""
+
+    def get_data_loader():
+        class MaskedSentencesDataset(torch.utils.data.Dataset):
+            def __init__(self, dataset_path, tokenizer, data_size) -> None:
+                self.data_size = data_size
+                self.dataset = self._get_val_dataset(dataset_path, data_size, tokenizer)
+
+            def _get_val_dataset(self, dataset_path, data_size, tokenizer):
+                # the collator defaults to MLM with a 15% masking probability
+                data_collator = transformers.DataCollatorForLanguageModeling(
+                    tokenizer=tokenizer
+                )
+                with open(dataset_path, "r") as f:
+                    texts = f.read().split("\n")
+                    # sample a pool of candidate sentences, dropping empty lines
+                    texts = [
+                        text for text in random.choices(texts, k=2000) if len(text) > 1
+                    ]
+                    dataset = data_collator([tokenizer(text) for text in texts])
+                return dataset
+
+            def __getitem__(self, idx):
+                return (
+                    self.dataset["input_ids"][idx].to(torch.int32),
+                    self.dataset["attention_mask"][idx].to(torch.float32),
+                    self.dataset["labels"][idx],
+                )
+
+            def __len__(self):
+                # never report more examples than the collator produced
+                return min(self.data_size, len(self.dataset["input_ids"]))
+
+        dataset = MaskedSentencesDataset(dataset_path, tokenizer, data_size)
+        return torch.utils.data.DataLoader(
+            dataset,
+            shuffle=shuffle,
+        )
+
+    # prepare input data (DataLoader batch_size defaults to 1)
+    inputs, targets = [], []
+    data_loader = get_data_loader()
+    for data in data_loader:
+        if len(inputs) >= data_size:
+            break
+        input_ids = data[0]
+        attention_mask = data[1]
+        target = data[2][0]
+        indices = [i for i, x in enumerate(target) if x != -100]
+        # skip samples where the collator masked no tokens
+        if len(indices) == 0:
+            continue
+        inputs.append((input_ids, attention_mask))
+        targets.append(target)
+
+    return inputs, targets
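
For context, here is a minimal usage sketch of the new helper. The model name, dataset path, and data_size below are illustrative placeholders, not part of this change; it assumes a Hugging Face tokenizer compatible with transformers' MLM data collator.

# Usage sketch; "bert-base-uncased" and "wiki.txt" are hypothetical placeholders.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
inputs, targets = get_masked_language_model_dataset("wiki.txt", tokenizer, data_size=100)

# each entry pairs int32 input_ids with a float32 attention_mask (batch dim kept)
input_ids, attention_mask = inputs[0]
masked_positions = (targets[0] != -100).nonzero()  # token positions the collator masked

The int32/float32 casts presumably match the input dtypes the lowered model expects; the (inputs, targets) lists can then serve as a calibration/evaluation set.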
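Relatedly, the quantized/float branch in build_executorch_binary follows the standard PT2E flow: annotate, calibrate, convert_pt2e, re-export. Below is a minimal generic sketch, assuming the XNNPACK quantizer as a stand-in for NeuropilotQuantizer (whose API is not shown in this diff); capture and quantizer entry points vary across torch/ExecuTorch versions.

import torch
from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer,
    get_symmetric_quantization_config,
)

model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU()).eval()
inputs = (torch.randn(1, 4),)

quantizer = XNNPACKQuantizer()
quantizer.set_global(get_symmetric_quantization_config())

captured = torch.export.export_for_training(model, inputs).module()
annotated = prepare_pt2e(captured, quantizer)  # insert observers
annotated(*inputs)  # calibration pass over representative data
quantized = convert_pt2e(annotated, fold_quantize=False)  # as in the diff
aten_dialect = torch.export.export(quantized, inputs, strict=True)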