diff --git a/examples/controlnet/train_controlnet.py b/examples/controlnet/train_controlnet.py index 2e902db7ffc7..8d140f5ae6d1 100644 --- a/examples/controlnet/train_controlnet.py +++ b/examples/controlnet/train_controlnet.py @@ -571,9 +571,6 @@ def parse_args(input_args=None): if args.dataset_name is None and args.train_data_dir is None: raise ValueError("Specify either `--dataset_name` or `--train_data_dir`") - if args.dataset_name is not None and args.train_data_dir is not None: - raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`") - if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") @@ -615,6 +612,7 @@ def make_train_dataset(args, tokenizer, accelerator): args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, + data_dir=args.train_data_dir, ) else: if args.train_data_dir is not None: diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index 877ca6135849..65981a3e07a4 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -598,9 +598,6 @@ def parse_args(input_args=None): if args.dataset_name is None and args.train_data_dir is None: raise ValueError("Specify either `--dataset_name` or `--train_data_dir`") - if args.dataset_name is not None and args.train_data_dir is not None: - raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`") - if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") @@ -642,6 +639,7 @@ def get_train_dataset(args, accelerator): args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, + data_dir=args.train_data_dir, ) else: if args.train_data_dir is not None: diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index 2ca511c857ae..929047b1156b 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -481,7 +481,6 @@ def parse_args(input_args=None): # Sanity checks if args.dataset_name is None and args.train_data_dir is None: raise ValueError("Need either a dataset name or a training folder.") - if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") @@ -819,9 +818,7 @@ def load_model_hook(models, input_dir): if args.dataset_name is not None: # Downloading and loading a dataset from the hub. dataset = load_dataset( - args.dataset_name, - args.dataset_config_name, - cache_dir=args.cache_dir, + args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, data_dir=args.train_data_dir ) else: data_files = {}