Commit cad7b3f

Update to Trainer.train to Allow Override Dataset (#142)
This update adds a new parameter, --training-dataset, that can optionally be set in the deepspeed call to pass in an alternate dataset for training. It adds a new constant for the databricks-dolly-15k dataset and updates several Trainer functions to accept a path override for the dataset.
1 parent 5f9bfba commit cad7b3f
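
With this change, a run can point training at an alternate dataset straight from the launcher. A minimal sketch of such an invocation, assuming a deepspeed module-style entry point; the GPU count, config path, and dataset name are placeholders, and any other required training options are omitted:

    # Launcher flags (--num_gpus, --module) and the config path are illustrative;
    # --training-dataset accepts any dataset name or path that load_dataset understands.
    deepspeed --num_gpus=8 --module training.trainer \
        --deepspeed path/to/ds_config.json \
        --training-dataset my-org/my-instruction-dataset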

File tree

2 files changed (+9, -5)


training/consts.py

Lines changed: 2 additions & 1 deletion
@@ -5,6 +5,7 @@
     "EleutherAI/pythia-12b",
     "EleutherAI/gpt-j-6B",
 ]
+DEFAULT_TRAINING_DATASET = "databricks/databricks-dolly-15k"
 INTRO_BLURB = (
     "Below is an instruction that describes a task. Write a response that appropriately completes the request."
 )
@@ -71,4 +72,4 @@
     instruction_key=INSTRUCTION_KEY,
     instruction="{instruction}",
     response_key=RESPONSE_KEY,
-)
+)

training/trainer.py

Lines changed: 7 additions & 4 deletions
@@ -38,6 +38,7 @@
     END_KEY,
     INSTRUCTION_KEY,
     RESPONSE_KEY_NL,
+    DEFAULT_TRAINING_DATASET,
 )

 logger = logging.getLogger(__name__)
@@ -84,7 +85,7 @@ def preprocess_batch(batch: Dict[str, List], tokenizer: AutoTokenizer, max_lengt
     )


-def load_training_dataset(path_or_dataset: str = "databricks/databricks-dolly-15k") -> Dataset:
+def load_training_dataset(path_or_dataset: str = DEFAULT_TRAINING_DATASET) -> Dataset:
     logger.info(f"Loading dataset from {path_or_dataset}")
     dataset = load_dataset(path_or_dataset)["train"]
     logger.info("Found %d rows", dataset.num_rows)
@@ -144,7 +145,7 @@ def get_model_tokenizer(
     return model, tokenizer


-def preprocess_dataset(tokenizer: AutoTokenizer, max_length: int, seed=DEFAULT_SEED) -> Dataset:
+def preprocess_dataset(tokenizer: AutoTokenizer, max_length: int, seed=DEFAULT_SEED, training_dataset: str = DEFAULT_TRAINING_DATASET) -> Dataset:
     """Loads the training dataset and tokenizes it so it is ready for training.

     Args:
@@ -155,7 +156,7 @@ def preprocess_dataset(tokenizer: AutoTokenizer, max_length: int, seed=DEFAULT_S
         Dataset: HuggingFace dataset
     """

-    dataset = load_training_dataset()
+    dataset = load_training_dataset(training_dataset)

     logger.info("Preprocessing dataset")
     _preprocessing_function = partial(preprocess_batch, max_length=max_length, tokenizer=tokenizer)
@@ -198,6 +199,7 @@ def train(
     test_size: Union[float, int],
     save_total_limit: int,
     warmup_steps: int,
+    training_dataset: str = DEFAULT_TRAINING_DATASET,
 ):
     set_seed(seed)

@@ -219,7 +221,7 @@
         max_length = 1024
         logger.info(f"Using default max length: {max_length}")

-    processed_dataset = preprocess_dataset(tokenizer=tokenizer, max_length=max_length, seed=seed)
+    processed_dataset = preprocess_dataset(tokenizer=tokenizer, max_length=max_length, seed=seed, training_dataset=training_dataset)

     split_dataset = processed_dataset.train_test_split(test_size=test_size, seed=seed)

@@ -301,6 +303,7 @@ def train(
 @click.option("--lr", type=float, default=1e-5, help="Learning rate to use for training.")
 @click.option("--seed", type=int, default=DEFAULT_SEED, help="Seed to use for training.")
 @click.option("--deepspeed", type=str, default=None, help="Path to deepspeed config file.")
+@click.option("--training-dataset", type=str, default=DEFAULT_TRAINING_DATASET, help="Path to dataset for training")
 @click.option(
     "--gradient-checkpointing/--no-gradient-checkpointing",
     is_flag=True,
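
Taken together, the override flows from the new CLI option into train(), on to preprocess_dataset(), and finally to load_training_dataset(). A minimal sketch of exercising the loader directly; the dataset name below is a hypothetical stand-in:

    from training.trainer import load_training_dataset

    # No argument: falls back to DEFAULT_TRAINING_DATASET,
    # i.e. "databricks/databricks-dolly-15k".
    default_ds = load_training_dataset()

    # Hypothetical override: any dataset name or local path that
    # datasets.load_dataset accepts, provided it exposes a "train" split.
    custom_ds = load_training_dataset("my-org/my-instruction-dataset")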
