@@ -38,6 +38,7 @@
     END_KEY,
     INSTRUCTION_KEY,
     RESPONSE_KEY_NL,
+    DEFAULT_TRAINING_DATASET,
 )

 logger = logging.getLogger(__name__)
@@ -84,7 +85,7 @@ def preprocess_batch(batch: Dict[str, List], tokenizer: AutoTokenizer, max_lengt
     )


-def load_training_dataset(path_or_dataset: str = "databricks/databricks-dolly-15k") -> Dataset:
+def load_training_dataset(path_or_dataset: str = DEFAULT_TRAINING_DATASET) -> Dataset:
     logger.info(f"Loading dataset from {path_or_dataset}")
     dataset = load_dataset(path_or_dataset)["train"]
     logger.info("Found %d rows", dataset.num_rows)
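For reference, DEFAULT_TRAINING_DATASET is presumably defined next to the other imported constants (END_KEY, INSTRUCTION_KEY, RESPONSE_KEY_NL) and carries the literal it replaces above, so the default dataset id lives in one place. A minimal sketch, assuming that consts module (file path hypothetical):

    # consts.py (hypothetical path): single source of truth for the default
    # dataset; the trainer imports this instead of hard-coding the HF repo id.
    DEFAULT_TRAINING_DATASET = "databricks/databricks-dolly-15k"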
@@ -144,7 +145,7 @@ def get_model_tokenizer(
     return model, tokenizer


-def preprocess_dataset(tokenizer: AutoTokenizer, max_length: int, seed=DEFAULT_SEED) -> Dataset:
+def preprocess_dataset(tokenizer: AutoTokenizer, max_length: int, seed=DEFAULT_SEED, training_dataset: str = DEFAULT_TRAINING_DATASET) -> Dataset:
     """Loads the training dataset and tokenizes it so it is ready for training.

     Args:
@@ -155,7 +156,7 @@ def preprocess_dataset(tokenizer: AutoTokenizer, max_length: int, seed=DEFAULT_S
         Dataset: HuggingFace dataset
     """

-    dataset = load_training_dataset()
+    dataset = load_training_dataset(training_dataset)

     logger.info("Preprocessing dataset")
     _preprocessing_function = partial(preprocess_batch, max_length=max_length, tokenizer=tokenizer)
@@ -198,6 +199,7 @@ def train(
     test_size: Union[float, int],
     save_total_limit: int,
     warmup_steps: int,
+    training_dataset: str = DEFAULT_TRAINING_DATASET,
 ):
     set_seed(seed)

@@ -219,7 +221,7 @@ def train(
         max_length = 1024
         logger.info(f"Using default max length: {max_length}")

-    processed_dataset = preprocess_dataset(tokenizer=tokenizer, max_length=max_length, seed=seed)
+    processed_dataset = preprocess_dataset(tokenizer=tokenizer, max_length=max_length, seed=seed, training_dataset=training_dataset)

     split_dataset = processed_dataset.train_test_split(test_size=test_size, seed=seed)

@@ -301,6 +303,7 @@ def train(
 @click.option("--lr", type=float, default=1e-5, help="Learning rate to use for training.")
 @click.option("--seed", type=int, default=DEFAULT_SEED, help="Seed to use for training.")
 @click.option("--deepspeed", type=str, default=None, help="Path to deepspeed config file.")
+@click.option("--training-dataset", type=str, default=DEFAULT_TRAINING_DATASET, help="Path to dataset for training")
 @click.option(
     "--gradient-checkpointing/--no-gradient-checkpointing",
     is_flag=True,
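With the option registered, a non-default dataset can be supplied at launch time; a minimal usage sketch, assuming the trainer is invoked as a module (the module path and dataset id below are hypothetical):

    # Fine-tune on a custom instruction dataset instead of the default;
    # omitting --training-dataset falls back to DEFAULT_TRAINING_DATASET.
    python -m training.trainer --training-dataset my-org/my-instruct-data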