Commit db38bcc

Remove deprecated InstructTemplate from llm_pte_finetuning example (pytorch#6557)

Summary: See pytorch#6552 for context. Here, we remove the InstructTemplate classes and replace them with a dataset builder that uses the necessary torchtune data components. The associated config is also updated.

Reviewed By: larryliu0820
Differential Revision: D65168404
Pulled By: RdoubleA

1 parent f0463c4 · commit db38bcc
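The removed InstructTemplate subclasses hand-formatted each dataset row into a prompt string. The replacement path instead goes through torchtune's AlpacaToMessages transform, which converts an instruction/input/output row into chat messages that SFTDataset can tokenize. A minimal sketch of that transform in isolation (the sample row is illustrative, not taken from the dataset):

```python
from torchtune.data import AlpacaToMessages

# train_on_input=False mirrors the new builder: the user turn is masked
# out of the loss and only the assistant response is trained on.
transform = AlpacaToMessages(train_on_input=False)

row = {
    "instruction": "Write a function that reverses a string.",
    "input": "",
    "output": "def reverse(s):\n    return s[::-1]",
}

# The transform returns the sample with a "messages" list of Message
# objects (user/assistant turns) ready for tokenization.
messages = transform(row)["messages"]
for msg in messages:
    print(msg.role, "masked" if msg.masked else "trained")
```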

2 files changed: +22 -51 lines
examples/llm_pte_finetuning/phi3_alpaca_code_config.yaml

Lines changed: 2 additions & 9 deletions
```diff
@@ -4,15 +4,8 @@ tokenizer:
   max_seq_len: 1024
 
 dataset:
-  _component_: torchtune.datasets.instruct_dataset
-  template: papaya.toolkit.experimental.llm_pte_finetuning.utils.DatabricksDolly
-  source: iamtarun/python_code_instructions_18k_alpaca
-  split: train
-  column_map:
-    instruction: instruction
-    prompt: prompt
-    input: input
-    output: output
+  _component_: executorch.examples.llm_pte_finetuning.training_lib.python_code_instructions_alpaca
+
 seed: null
 shuffle: True
 batch_size: 1
```
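With the template and column_map plumbing gone, the config only needs a dotted `_component_` path to the new builder. A sketch, assuming the standard torchtune config flow, of how a recipe resolves that entry (the wiring here is illustrative, not this example's actual runner):

```python
from omegaconf import OmegaConf
from torchtune import config

cfg = OmegaConf.load("examples/llm_pte_finetuning/phi3_alpaca_code_config.yaml")

# config.instantiate imports the dotted `_component_` path and calls it,
# passing along any extra positional args, here the tokenizer, which is
# the builder's only parameter.
tokenizer = config.instantiate(cfg.tokenizer)
dataset = config.instantiate(cfg.dataset, tokenizer)
```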

examples/llm_pte_finetuning/training_lib.py

Lines changed: 20 additions & 42 deletions
```diff
@@ -7,15 +7,17 @@
 # pyre-strict
 
 from functools import partial
-from typing import Any, Dict, Mapping, Optional
+from typing import Any
 
 import torch
 from executorch.extension.pybindings.aten_lib import ExecuTorchModule  # @manual
 
 from torch.nn import functional as F
 from torch.utils.data import DataLoader, Dataset, DistributedSampler
-from torchtune.data import InstructTemplate
+from torchtune.data import AlpacaToMessages
 from torchtune.data._collate import padded_collate_sft
+from torchtune.datasets import PackedDataset, SFTDataset
+from torchtune.modules.tokenizers import ModelTokenizer
 from tqdm import tqdm
 
 
@@ -44,49 +46,25 @@ def forward(self, input: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
         return self.loss(logits, labels)
 
 
-class DatabricksDolly(InstructTemplate):
+def python_code_instructions_alpaca(tokenizer: ModelTokenizer) -> PackedDataset:
     """
-    Used for the Dolly dataset from Databricks.
-
-    https://huggingface.co/datasets/databricks/databricks-dolly-15k
-    """
-
-    template = "Instruction:\n{instruction}\n\nContext:\n{input}\n\nResponse: "
-
-    @classmethod
-    def format(
-        cls,
-        sample: Mapping[str, Any],
-        column_map: Optional[Dict[str, str]],
-    ) -> str:
-        assert column_map is not None
-        instruction = sample[column_map["instruction"]]
-        input = sample[column_map["input"]]
-        return cls.template.format(instruction=instruction, input=input)
-
-
-class PythonCodeInstructions(InstructTemplate):
-    """
-    https://huggingface.co/datasets/iamtarun/python_code_instructions_18k_alpaca
+    Python code instruction-input-output pairs from iamtarun/python_code_instructions_18k_alpaca templated with Alpaca.
     """
-
-    template = (
-        "{prompt}\n\n"
-        "Instruction:\n{instruction}"
-        "\n\nContext:\n{input}\n\nResponse: "
+    ds = SFTDataset(
+        # pyre-ignore[6]: Incompatible parameter type
+        model_transform=tokenizer,
+        source="iamtarun/python_code_instructions_18k_alpaca",
+        message_transform=AlpacaToMessages(
+            train_on_input=False,
+        ),
+        # pyre-ignore[6]: Incompatible parameter type
+        split="train",
     )
-
-    @classmethod
-    def format(
-        cls,
-        sample: Mapping[str, Any],
-        column_map: Optional[Dict[str, str]],
-    ) -> str:
-        assert column_map is not None
-        instruction = sample[column_map["instruction"]]
-        input = sample[column_map["input"]]
-        prompt = sample[column_map["prompt"]]
-        return cls.template.format(instruction=instruction, input=input, prompt=prompt)
+    if tokenizer.max_seq_len is None:
+        raise ValueError(
+            "PackedDataset requires a max_seq_len to be set on the tokenizer."
+        )
+    return PackedDataset(ds, max_seq_len=tokenizer.max_seq_len, split_across_pack=False)
 
 
 def update_function(
```
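For reference, a hedged usage sketch of the new builder outside a recipe. The tokenizer constructor and checkpoint path are assumptions for illustration, not part of this commit:

```python
from executorch.examples.llm_pte_finetuning.training_lib import (
    python_code_instructions_alpaca,
)
from torchtune.models.phi3 import phi3_mini_tokenizer

# Placeholder path; max_seq_len must be set or the builder raises ValueError,
# since PackedDataset packs every sample to that fixed length.
tokenizer = phi3_mini_tokenizer(
    path="/tmp/Phi-3-mini-4k-instruct/tokenizer.model",
    max_seq_len=1024,
)

ds = python_code_instructions_alpaca(tokenizer)
print(len(ds))   # number of packed sequences
sample = ds[0]   # dict of fixed-length fields (tokens, labels, ...)
```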
