# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 3 Community License Agreement.


import copy
import itertools

from datasets import load_dataset


# Check whether any of the 3-token prompt header sequences in `targets`
# (system or user header) appears in the token sequence `seq`.
def check_header(targets, seq):
    for i in range(len(seq) - 2):
        if seq[i:i + 3] in targets:
            return True
    return False


# Replace every occurrence of the 3-token `target` sequence (the assistant
# prompt header) in `seq` with the ignore index -100, so the header itself
# is excluded from the loss.
def replace_target(target, seq):
    for i in range(len(seq) - 2):
        if seq[i:i + 3] == target:
            seq[i], seq[i + 1], seq[i + 2] = -100, -100, -100
    return seq
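

# A minimal sketch of how the two helpers behave. The header ids below match
# the Llama 3 header ids documented in tokenize_dialog; the trailing id 271
# is just an arbitrary content token chosen for illustration.
#
#   check_header([[128006, 9125, 128007]], [128006, 9125, 128007, 128009])
#   -> True
#   replace_target([128006, 78191, 128007], [128006, 78191, 128007, 271])
#   -> [-100, -100, -100, 271]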


def tokenize_dialog(dialog, tokenizer):
    # A vocab size of at least 128000 indicates a Llama 3 family tokenizer,
    # so use its chat template to generate the token ids.
    if tokenizer.vocab_size >= 128000:
        dialog_tokens = tokenizer.apply_chat_template(dialog)
        # 128009 is the id of the end-of-turn token <|eot_id|>.
        eot_indices = [i for i, n in enumerate(dialog_tokens) if n == 128009]
        labels = copy.copy(dialog_tokens)
        last_idx = 0
        # The system prompt header "<|start_header_id|>system<|end_header_id|>" is tokenized to [128006, 9125, 128007];
        # the user prompt header "<|start_header_id|>user<|end_header_id|>" is tokenized to [128006, 882, 128007].
        prompt_header_seqs = [[128006, 9125, 128007], [128006, 882, 128007]]
        for idx in eot_indices:
            current_seq = labels[last_idx:idx + 1]
            if check_header(prompt_header_seqs, current_seq):
                # Found a prompt header, so this segment is a system or user
                # turn and should be masked out of the loss.
                labels[last_idx:idx + 1] = [-100] * (idx - last_idx + 1)
            else:
                last_idx = idx
        # Lastly, mask the assistant prompt header "<|start_header_id|>assistant<|end_header_id|>",
        # which is tokenized to [128006, 78191, 128007].
        assistant_header_seq = [128006, 78191, 128007]
        labels = replace_target(assistant_header_seq, labels)
        dialog_tokens = [dialog_tokens]
        labels_tokens = [labels]
    else:
        raise Exception("This raft_dataset only supports Llama 3 family models; please make sure the tokenizer is from a Llama 3 family model.")

    combined_tokens = {
        "input_ids": list(itertools.chain(*dialog_tokens)),
        "labels": list(itertools.chain(*labels_tokens)),
    }

    return dict(combined_tokens, attention_mask=[1] * len(combined_tokens["input_ids"]))
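

# Shape sketch (an illustration, not part of the original module): for a
# three-turn dialog
#   [{"role": "system", "content": "..."},
#    {"role": "user", "content": "..."},
#    {"role": "assistant", "content": "..."}]
# tokenize_dialog returns {"input_ids": [...], "labels": [...],
# "attention_mask": [1, 1, ...]}, where the system and user turns and all
# headers carry the label -100, and only the assistant answer tokens (plus
# their trailing <|eot_id|>) keep their token ids as labels.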


def raft_tokenize(q_a_pair, tokenizer):
    end_tag = "</DOCUMENT>"
    # Find the last end_tag in the instruction; the rest is the question.
    try:
        index = q_a_pair["instruction"].rindex(end_tag) + len(end_tag)
    except ValueError:
        print(q_a_pair["instruction"])
        raise Exception("The instruction does not contain the end tag </DOCUMENT>")
    # Everything after end_tag is the question.
    question = q_a_pair["instruction"][index:].strip()
    # Everything before (and including) end_tag is the context.
    documents = q_a_pair["instruction"][:index].strip()
    # The output field is the label.
    answer = q_a_pair["output"]
    system_prompt = "You are a helpful chatbot who can provide an answer to every question from the user given a relevant context."
    user_prompt = """
    Question: {question}\nContext: {context}\n
    Answer this question using the information given by multiple documents in the context above. Here are the things to pay attention to:
    - The context contains many documents, each of which starts with <DOCUMENT> and ends with </DOCUMENT>.
    - First provide step-by-step reasoning on how to answer the question.
    - In the reasoning, if you need to copy-paste some sentences from the context, include them in ##begin_quote## and ##end_quote##. This means that anything outside of ##begin_quote## and ##end_quote## is not directly copy-pasted from the context.
    - End your response with the final answer in the form <ANSWER>: $answer; the answer should be less than 60 words.
    You MUST begin your final answer with the tag "<ANSWER>:".
    """.format(question=question, context=documents)

    chat = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
        {"role": "assistant", "content": answer},
    ]
    return tokenize_dialog(chat, tokenizer)
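

# A minimal sketch of the record format raft_tokenize expects, following the
# RAFT-style "<DOCUMENT>...</DOCUMENT>" convention above; the field values
# here are invented for illustration.
#
#   sample = {
#       "instruction": "<DOCUMENT>France is in Europe.</DOCUMENT>\n"
#                      "What continent is France in?",
#       "output": "##begin_quote##France is in Europe.##end_quote## "
#                 "<ANSWER>: Europe",
#   }
#   tokenized = raft_tokenize(sample, tokenizer)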


def get_custom_dataset(dataset_config, tokenizer, split, split_ratio=0.9):
    # load_dataset returns a DatasetDict whose "train" split holds all the data.
    dataset_dict = load_dataset("json", data_files=dataset_config.data_path)
    dataset = dataset_dict["train"]
    dataset = dataset.train_test_split(test_size=1 - split_ratio, shuffle=True, seed=42)

    # Keep the instruction field and expose the chain-of-thought answer as the output field.
    dataset = dataset[split].map(
        lambda sample: {
            "instruction": sample["instruction"],
            "output": sample["cot_answer"],
        },
        batched=True,
    )
    dataset = dataset.map(lambda x: raft_tokenize(x, tokenizer))
    return dataset
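

if __name__ == "__main__":
    # A minimal smoke test, not part of the original module. The model id and
    # the "raft.jsonl" path are assumptions: loading a Llama 3 tokenizer
    # requires access to the gated checkpoint on the Hugging Face Hub, and the
    # data file must contain "instruction" and "cot_answer" fields.
    from dataclasses import dataclass

    from transformers import AutoTokenizer

    @dataclass
    class DatasetConfig:
        data_path: str = "raft.jsonl"

    tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B-Instruct")
    train_set = get_custom_dataset(DatasetConfig(), tokenizer, "train")
    print(train_set[0]["input_ids"][:10])
    print(train_set[0]["labels"][:10])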