-from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, LlamaTokenizerFast, GPTNeoXTokenizerFast
+from transformers import AutoModelForCausalLM, AutoTokenizer, LlamaTokenizer, LlamaTokenizerFast, GPTNeoXTokenizerFast, GPT2Tokenizer
 import fire
 from trl import SFTTrainer, DataCollatorForCompletionOnlyLM
 import transformers
@@ -77,21 +77,29 @@ def train(
     tokenizer = transformers.AutoTokenizer.from_pretrained(
         model_args.model_name_or_path,
         cache_dir=train_args.cache_dir,
-        padding_side="right",
         use_fast=True
     )
+
+    # TODO: understand if we need to hardcode these here or just use the defaults from the model
     if isinstance(tokenizer, LlamaTokenizer) or isinstance(tokenizer, LlamaTokenizerFast):
         tokenizer.add_special_tokens({
             "bos_token": "<s>",
             "eos_token": "</s>",
             "unk_token": "<unk>",
             "pad_token": "<pad>",
         })
-    elif isinstance(tokenizer, GPTNeoXTokenizerFast):
+    elif isinstance(tokenizer, GPTNeoXTokenizerFast) or isinstance(tokenizer, GPT2Tokenizer):
         tokenizer.add_special_tokens({
             "pad_token": "<pad>",
         })
-
+
+    """TODO: near term - how the response template ids are parsed out needs to be cleaned up.
+    The [2:] applies when the response template has a "\n" prefix; it strips the leading ids, otherwise the template is not found.
+    We will create an issue to clean this up after we discuss which data formats and collators we will support.
+    """
+    response_template_ids = tokenizer.encode(data_args.response_template, add_special_tokens=False)[2:]
+    # TODO: this is actually max_seq_length, not model_max_length. We should not override model_max_length
+    # as in current main. We need to change the name of the parameter we expose to users.
     model_max_length = min(train_args.model_max_length, tokenizer.model_max_length)
     logger.info(f"Model max length {model_max_length}")
     if train_args.model_max_length > tokenizer.model_max_length:
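Note on the [2:] slice above: SentencePiece-based Llama tokenizers encode a leading newline differently when the template is tokenized on its own versus inside a longer sequence, so the first ids have to be dropped for DataCollatorForCompletionOnlyLM to find the template. A minimal sketch of how to verify this, assuming a Llama-style checkpoint and a "\n### Response:" template (both placeholders, not taken from this diff):

# Sketch: compare the response template tokenized standalone vs. in context.
# "some-llama-checkpoint" and the template string are placeholders.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("some-llama-checkpoint", use_fast=True)
template = "\n### Response:"

standalone = tok.encode(template, add_special_tokens=False)
in_context = tok.encode("Tell me a joke." + template, add_special_tokens=False)

# The first two standalone ids (prefix-space and newline pieces) do not appear
# in that form mid-sequence, so the collator can only match standalone[2:].
print("standalone    :", standalone)
print("standalone[2:]:", standalone[2:])
print("in context    :", in_context)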
@@ -112,6 +120,8 @@ def train(
         logger.warning("UNK token set to default, missing in tokenizer")
         special_tokens_dict["unk_token"] = configs.DEFAULT_UNK_TOKEN
 
+    # TODO: lower priority, but understand if resizing impacts inference quality and why it is needed.
+    # It makes sense that if we manipulate the tokenizer we also save it and provide it to inference.
     tokenizer_data_utils.tokenizer_and_embedding_resize(
         special_tokens_dict=special_tokens_dict,
         tokenizer=tokenizer,
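tokenizer_and_embedding_resize itself is not shown in this diff; a rough sketch of the usual recipe behind such a helper is below (illustrative only, not necessarily what tokenizer_data_utils implements). It also shows why the resized tokenizer should be saved alongside the model for inference: after the resize the embedding matrix no longer matches the base tokenizer's vocabulary size.

# Sketch of a typical resize helper; not the repo's actual implementation.
def tokenizer_and_embedding_resize(special_tokens_dict, tokenizer, model):
    num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
    model.resize_token_embeddings(len(tokenizer))
    if num_new_tokens > 0:
        input_emb = model.get_input_embeddings().weight.data
        output_emb = model.get_output_embeddings().weight.data
        # Initialize the new rows with the mean of the existing rows, a common
        # choice so the new <pad>/<unk> tokens start from a reasonable point.
        input_emb[-num_new_tokens:] = input_emb[:-num_new_tokens].mean(dim=0, keepdim=True)
        output_emb[-num_new_tokens:] = output_emb[:-num_new_tokens].mean(dim=0, keepdim=True)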
@@ -120,7 +130,8 @@ def train(
 
     # load the data by parsing JSON
     json_dataset = datasets.load_dataset('json', data_files=data_args.data_path)
-    logger.info(f"Dataset length is {len(json_dataset['train'])}")
+    formatted_dataset = json_dataset['train'].map(lambda example: {f"{data_args.dataset_text_field}": example[f"{data_args.dataset_text_field}"] + tokenizer.eos_token})
+    logger.info(f"Dataset length is {len(formatted_dataset)}")
 
     aim_callback = get_aimstack_callback()
     callbacks = [aim_callback, PeftSavingCallback()]
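The map call appends the EOS token to every example's text field so the model learns to stop after the completion (packing is disabled below, so nothing else adds it). A toy illustration, with the field name, record, and hard-coded "</s>" invented for the example:

# Toy illustration of the eos-append map; field name and record are made up.
from datasets import Dataset

ds = Dataset.from_list([{"output": "### Input:\nhi\n### Response:\nhello"}])
ds = ds.map(lambda example: {"output": example["output"] + "</s>"})
print(ds[0]["output"])  # ### Input:\nhi\n### Response:\nhello</s>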
@@ -138,15 +149,14 @@ def train(
     if data_args.dataset_text_field is None:
         logger.error("Error, dataset_text_field is None, needs to be set for training")
         exit(-1)
-
-    response_template_ids = tokenizer.encode(data_args.response_template, add_special_tokens=False)[2:]
+
     data_collator = DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer=tokenizer, ignore_index=configs.IGNORE_INDEX)
     packing = False
 
     trainer = SFTTrainer(
         model=model,
         tokenizer=tokenizer,
-        train_dataset=json_dataset['train'],
+        train_dataset=formatted_dataset,
         packing=packing,
         data_collator=data_collator,
         dataset_text_field=data_args.dataset_text_field,
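A quick way to sanity-check the completion-only collator: every label position up to and including the response template should be set to the ignore index, so loss is computed only on the completion. A sketch, reusing the placeholder checkpoint and template from the earlier note:

# Sketch: inspect which label positions the completion-only collator masks.
from transformers import AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM

tok = AutoTokenizer.from_pretrained("some-llama-checkpoint", use_fast=True)
if tok.pad_token is None:
    tok.pad_token = tok.eos_token  # the training script adds a real <pad>; this is only for the sketch
response_template_ids = tok.encode("\n### Response:", add_special_tokens=False)[2:]

text = "### Input:\nhi\n### Response:\nhello" + tok.eos_token
input_ids = tok(text, add_special_tokens=False)["input_ids"]

collator = DataCollatorForCompletionOnlyLM(response_template_ids, tokenizer=tok, ignore_index=-100)
batch = collator([{"input_ids": input_ids}])

# Label positions equal to -100 are excluded from the loss (prompt + template);
# the remaining label ids should cover only "hello" plus the EOS token.
print(batch["labels"][0].tolist())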