
Commit d19e1d3

adding the trust_remote_code argument to load_dataset function calls
1 parent d5bea15 commit d19e1d3

6 files changed (+27 -17 lines)

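All six files pass the same new keyword to datasets.load_dataset. For orientation, a minimal sketch of that call outside the training scripts follows; the dataset repo and literal values are placeholders standing in for args.dataset_name, args.dataset_config_name, args.cache_dir, and args.trust_remote_code, under the assumption that newer releases of the datasets library default trust_remote_code to False and refuse to run a Hub dataset's custom loading script unless the caller opts in.

from datasets import load_dataset

# Minimal sketch (not part of this commit): loading a Hub dataset that ships a
# custom loading script. With trust_remote_code left at its default, newer
# `datasets` releases are expected to reject the script; opting in lets it run
# locally.
dataset = load_dataset(
    "fusing/fill50k",        # placeholder for args.dataset_name
    None,                    # placeholder for args.dataset_config_name
    cache_dir=None,          # placeholder for args.cache_dir
    trust_remote_code=True,  # placeholder for args.trust_remote_code
)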

examples/controlnet/train_controlnet_flax.py

Lines changed: 2 additions & 0 deletions
@@ -521,6 +521,7 @@ def make_train_dataset(args, tokenizer, batch_size=None):
             args.dataset_config_name,
             cache_dir=args.cache_dir,
             streaming=args.streaming,
+            trust_remote_code=args.trust_remote_code
         )
     else:
         if args.train_data_dir is not None:
@@ -532,6 +533,7 @@ def make_train_dataset(args, tokenizer, batch_size=None):
             dataset = load_dataset(
                 args.train_data_dir,
                 cache_dir=args.cache_dir,
+                trust_remote_code=args.trust_remote_code
             )
             # See more about loading custom images at
             # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script

examples/controlnet/train_controlnet_flux.py

Lines changed: 17 additions & 17 deletions
@@ -13,30 +13,29 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and

+import accelerate
 import argparse
 import copy
 import functools
 import logging
 import math
+import numpy as np
 import os
 import random
 import shutil
-from contextlib import nullcontext
-from pathlib import Path
-
-import accelerate
-import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
 import transformers
+from PIL import Image
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import DistributedType, ProjectConfiguration, set_seed
+from contextlib import nullcontext
 from datasets import load_dataset
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
-from PIL import Image
+from pathlib import Path
 from torchvision import transforms
 from tqdm.auto import tqdm
 from transformers import (
@@ -60,7 +59,6 @@
 from diffusers.utils.import_utils import is_torch_npu_available, is_xformers_available
 from diffusers.utils.torch_utils import is_compiled_module

-
 if is_wandb_available():
     import wandb

@@ -73,7 +71,7 @@


 def log_validation(
-    vae, flux_transformer, flux_controlnet, args, accelerator, weight_dtype, step, is_final_validation=False
+        vae, flux_transformer, flux_controlnet, args, accelerator, weight_dtype, step, is_final_validation=False
 ):
     logger.info("Running validation... ")

@@ -266,7 +264,7 @@ def parse_args(input_args=None):
         type=str,
         default=None,
         help="Path to pretrained controlnet model or model identifier from huggingface.co/models."
-        " If not specified controlnet weights are initialized from unet.",
+             " If not specified controlnet weights are initialized from unet.",
     )
     parser.add_argument(
         "--variant",
@@ -668,11 +666,11 @@ def parse_args(input_args=None):
         raise ValueError("`--validation_prompt` must be set if `--validation_image` is set")

     if (
-        args.validation_image is not None
-        and args.validation_prompt is not None
-        and len(args.validation_image) != 1
-        and len(args.validation_prompt) != 1
-        and len(args.validation_image) != len(args.validation_prompt)
+            args.validation_image is not None
+            and args.validation_prompt is not None
+            and len(args.validation_image) != 1
+            and len(args.validation_prompt) != 1
+            and len(args.validation_image) != len(args.validation_prompt)
     ):
         raise ValueError(
             "Must provide either 1 `--validation_image`, 1 `--validation_prompt`,"
@@ -695,10 +693,12 @@ def get_train_dataset(args, accelerator):
             args.dataset_name,
             args.dataset_config_name,
             cache_dir=args.cache_dir,
+            trust_remote_code=args.trust_remote_code
         )
     if args.jsonl_for_train is not None:
         # load from json
-        dataset = load_dataset("json", data_files=args.jsonl_for_train, cache_dir=args.cache_dir)
+        dataset = load_dataset("json", data_files=args.jsonl_for_train, cache_dir=args.cache_dir,
+                               trust_remote_code=args.trust_remote_code)
         dataset = dataset.flatten_indices()
     # Preprocessing the datasets.
     # We need to tokenize inputs and targets.
@@ -1018,7 +1018,7 @@ def load_model_hook(models, input_dir):

     if args.scale_lr:
         args.learning_rate = (
-            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
+                args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )

     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
@@ -1130,7 +1130,7 @@ def compute_embeddings(batch, proportion_empty_prompts, flux_controlnet_pipeline
         len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
         num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
         num_training_steps_for_scheduler = (
-            args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
+                args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
         )
     else:
         num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
examples/controlnet/train_controlnet_sd3.py

Lines changed: 2 additions & 0 deletions
@@ -650,12 +650,14 @@ def make_train_dataset(args, tokenizer_one, tokenizer_two, tokenizer_three, acce
             args.dataset_name,
             args.dataset_config_name,
             cache_dir=args.cache_dir,
+            trust_remote_code=args.trust_remote_code
         )
     else:
         if args.train_data_dir is not None:
             dataset = load_dataset(
                 args.train_data_dir,
                 cache_dir=args.cache_dir,
+                trust_remote_code=args.trust_remote_code
             )
             # See more about loading custom images at
             # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script

examples/controlnet/train_controlnet_sdxl.py

Lines changed: 2 additions & 0 deletions
@@ -645,12 +645,14 @@ def get_train_dataset(args, accelerator):
             args.dataset_config_name,
             cache_dir=args.cache_dir,
             data_dir=args.train_data_dir,
+            trust_remote_code=args.trust_remote_code
         )
     else:
         if args.train_data_dir is not None:
             dataset = load_dataset(
                 args.train_data_dir,
                 cache_dir=args.cache_dir,
+                trust_remote_code=args.trust_remote_code
             )
             # See more about loading custom images at
             # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script

examples/research_projects/pixart/train_pixart_controlnet_hf.py

Lines changed: 2 additions & 0 deletions
@@ -771,6 +771,7 @@ def load_model_hook(models, input_dir):
             args.dataset_config_name,
             cache_dir=args.cache_dir,
             data_dir=args.train_data_dir,
+            trust_remote_code=args.trust_remote_code
         )
     else:
         data_files = {}
@@ -780,6 +781,7 @@ def load_model_hook(models, input_dir):
             "imagefolder",
             data_files=data_files,
             cache_dir=args.cache_dir,
+            trust_remote_code=args.trust_remote_code
         )
         # See more about loading custom images at
         # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder

examples/t2i_adapter/train_t2i_adapter_sdxl.py

Lines changed: 2 additions & 0 deletions
@@ -637,12 +637,14 @@ def get_train_dataset(args, accelerator):
             args.dataset_name,
             args.dataset_config_name,
             cache_dir=args.cache_dir,
+            trust_remote_code=args.trust_remote_code
         )
     else:
         if args.train_data_dir is not None:
             dataset = load_dataset(
                 args.train_data_dir,
                 cache_dir=args.cache_dir,
+                trust_remote_code=args.trust_remote_code
             )
             # See more about loading custom images at
             # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
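The diffs reference args.trust_remote_code, but the flag's definition is not part of this commit. Assuming the scripts follow the usual argparse pattern in these examples, the corresponding option would look roughly like the sketch below; the exact help text and placement are guesses, not taken from the repository.

# Sketch only: this commit does not show where args.trust_remote_code is defined.
parser.add_argument(
    "--trust_remote_code",
    action="store_true",
    help=(
        "Whether to allow datasets that define their loading logic in a custom script on the Hub. "
        "Only set this for repositories you trust, since the script runs locally."
    ),
)

With such a flag, a training run would opt in explicitly, e.g. accelerate launch train_controlnet_sdxl.py --dataset_name <hub-dataset> --trust_remote_code (command shown for illustration only).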

0 commit comments