 # pyre-strict

 from functools import partial
-from typing import Any, Dict, Mapping, Optional
+from typing import Any

 import torch
 from executorch.extension.pybindings.aten_lib import ExecuTorchModule  # @manual

 from torch.nn import functional as F
 from torch.utils.data import DataLoader, Dataset, DistributedSampler
-from torchtune.data import InstructTemplate
+from torchtune.data import AlpacaToMessages
 from torchtune.data._collate import padded_collate_sft
+from torchtune.datasets import PackedDataset, SFTDataset
+from torchtune.modules.tokenizers import ModelTokenizer
 from tqdm import tqdm


@@ -44,49 +46,25 @@ def forward(self, input: torch.Tensor, labels: torch.Tensor) -> torch.Tensor: |
         return self.loss(logits, labels)


-class DatabricksDolly(InstructTemplate):
+def python_code_instructions_alpaca(tokenizer: ModelTokenizer) -> PackedDataset:
     """
-    Used for the Dolly dataset from Databricks.
-
-    https://huggingface.co/datasets/databricks/databricks-dolly-15k
-    """
-
-    template = "Instruction:\n{instruction}\n\nContext:\n{input}\n\nResponse: "
-
-    @classmethod
-    def format(
-        cls,
-        sample: Mapping[str, Any],
-        column_map: Optional[Dict[str, str]],
-    ) -> str:
-        assert column_map is not None
-        instruction = sample[column_map["instruction"]]
-        input = sample[column_map["input"]]
-        return cls.template.format(instruction=instruction, input=input)
-
-
-class PythonCodeInstructions(InstructTemplate):
-    """
-    https://huggingface.co/datasets/iamtarun/python_code_instructions_18k_alpaca
+    Python code instruction-input-output pairs from iamtarun/python_code_instructions_18k_alpaca templated with Alpaca.
     """
-
-    template = (
-        "{prompt}\n\n"
-        "Instruction:\n{instruction}"
-        "\n\nContext:\n{input}\n\nResponse: "
+    ds = SFTDataset(
+        # pyre-ignore[6]: Incompatible parameter type
+        model_transform=tokenizer,
+        source="iamtarun/python_code_instructions_18k_alpaca",
+        message_transform=AlpacaToMessages(
+            train_on_input=False,
+        ),
+        # pyre-ignore[6]: Incompatible parameter type
+        split="train",
     )
-
-    @classmethod
-    def format(
-        cls,
-        sample: Mapping[str, Any],
-        column_map: Optional[Dict[str, str]],
-    ) -> str:
-        assert column_map is not None
-        instruction = sample[column_map["instruction"]]
-        input = sample[column_map["input"]]
-        prompt = sample[column_map["prompt"]]
-        return cls.template.format(instruction=instruction, input=input, prompt=prompt)
+    if tokenizer.max_seq_len is None:
+        raise ValueError(
+            "PackedDataset requires a max_seq_len to be set on the tokenizer."
+        )
+    return PackedDataset(ds, max_seq_len=tokenizer.max_seq_len, split_across_pack=False)


 def update_function(
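
For reference, a minimal usage sketch of the new helper, with python_code_instructions_alpaca imported from this module. This is not part of the commit; the llama3 tokenizer and max_seq_len value are illustrative assumptions, and any torchtune ModelTokenizer constructed with max_seq_len set should work.

# Illustrative only: assumes a torchtune llama3 tokenizer and max_seq_len=512.
from torchtune.models.llama3 import llama3_tokenizer

tokenizer = llama3_tokenizer("/path/to/tokenizer.model", max_seq_len=512)

# Builds the Alpaca-templated SFT dataset and packs samples into
# sequences of up to max_seq_len tokens each.
ds = python_code_instructions_alpaca(tokenizer)

print(len(ds))       # number of packed sequences
print(ds[0].keys())  # typically includes "tokens" and "labels"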