diff --git a/diffsynth/trainers/dataset.py b/diffsynth/trainers/dataset.py new file mode 100644 index 00000000..1af2ceec --- /dev/null +++ b/diffsynth/trainers/dataset.py @@ -0,0 +1,345 @@ +from typing import Optional, Tuple, List, Dict, Any, Union, Set, Callable +import imageio, os, torch, warnings, torchvision, argparse, json +from peft import LoraConfig, inject_adapter_in_model +from PIL import Image +import pandas as pd +from tqdm import tqdm +from accelerate import Accelerator +from .dataset_key_configs import get_default_config, get_loader + +class BaseDataset(torch.utils.data.Dataset): + def __init__( + self, + base_path: str = None, + metadata_path: str = None, + max_pixels: int = 1920 * 1080, + height: int = None, + width: int = None, + height_division_factor: int = 16, + width_division_factor: int = 16, + default_key_model: str = "flux", + input_configs: str = None, + file_extensions: Tuple[str, ...] = ("jpg", "mp4"), + generated_target_key: str = "image", + repeat: int = 1, + ): + self.base_path = base_path if base_path is not None else "" + self.max_pixels = max_pixels + self.height = height + self.width = width + self.height_division_factor = height_division_factor + self.width_division_factor = width_division_factor + self.file_extensions = file_extensions + self.generated_target_key = generated_target_key + self.repeat = repeat + self.keyconfigs = self.parse_keyconfigs(input_configs, default_key_model) + data = self.load_meta(metadata_path) + self.data = self.preprocess(data) + + if height is not None and width is not None: + print("Fixed resolution. Setting `dynamic_resolution` to False.") + self.dynamic_resolution = False + else: + print("Dynamic resolution enabled.") + self.dynamic_resolution = True + + + def parse_keyconfigs(self, input_configs, default_key_model): + if input_configs is None: + print("No input configs provided; the dataset will contain no keys.") + return {} + default_configs = get_default_config(default_key_model) + keyconfigs = {} + for item in input_configs.split(","): + if ":" in item: + key, value = item.split(":", 1) + keyconfigs[key] = value + else: + keyconfigs[item] = default_configs[item] + print(f"Using dataset key configurations: {keyconfigs}") + return keyconfigs + + + def load_meta(self, metadata_path, base_path=None): + if metadata_path is None: + print("No metadata. Trying to generate it.") + metadata = self.generate_metadata(base_path if base_path is not None else self.base_path) + metadata = [metadata.iloc[i].to_dict() for i in range(len(metadata))] + elif metadata_path.endswith(".json"): + with open(metadata_path, "r") as f: + metadata = json.load(f) + elif metadata_path.endswith(".jsonl"): + metadata = [] + with open(metadata_path, 'r') as f: + for line in tqdm(f): + metadata.append(json.loads(line.strip())) + else: + metadata = pd.read_csv(metadata_path) + metadata = [metadata.iloc[i].to_dict() for i in range(len(metadata))] + print(f"Successfully loaded {len(metadata)} metadata entries.") + return metadata + + + def generate_metadata(self, folder): + file_list, prompt_list = [], [] + file_set = set(os.listdir(folder)) + for file_name in file_set: + if "." 
not in file_name: + continue + file_ext_name = file_name.split(".")[-1].lower() + file_base_name = file_name[:-len(file_ext_name)-1] + if file_ext_name not in self.file_extensions: + continue + prompt_file_name = file_base_name + ".txt" + if prompt_file_name not in file_set: + continue + with open(os.path.join(folder, prompt_file_name), "r", encoding="utf-8") as f: + prompt = f.read().strip() + file_list.append(file_name) + prompt_list.append(prompt) + metadata = pd.DataFrame() + metadata[self.generated_target_key] = file_list + metadata["prompt"] = prompt_list + return metadata + + + def convert_to_absolute_path(self, path): + if isinstance(path, list) or isinstance(path, tuple): + return [os.path.join(self.base_path, p) for p in path] + else: + return os.path.join(self.base_path, path) + + + def check_file_existence(self, path): + if isinstance(path, list) or isinstance(path, tuple): + for p in path: + assert os.path.exists(p), f"file {p} does not exist." + else: + assert os.path.exists(path), f"file {path} does not exist." + + + def preprocess(self, data): + required_keys = list(self.keyconfigs.keys()) + file_keys = [k for k in required_keys if self.keyconfigs[k] in ("image", "video", "tensor")] + new_data = [] + for cur_data in tqdm(data): + try: + # fetch all required keys + cur_data = {k: cur_data[k] for k in required_keys} + # convert file paths to absolute paths and check existence + for file_key in file_keys: + cur_data[file_key] = self.convert_to_absolute_path(cur_data[file_key]) + # self.check_file_existence(cur_data[file_key]) + # add to filtered data + new_data.append(cur_data) + except Exception: + continue + print(f"Kept {len(new_data)} valid samples out of {len(data)} metadata entries.") + return new_data + + + def crop_and_resize(self, image, target_height, target_width): + width, height = image.size + scale = max(target_width / width, target_height / height) + image = torchvision.transforms.functional.resize( + image, + (round(height*scale), round(width*scale)), + interpolation=torchvision.transforms.InterpolationMode.BILINEAR + ) + image = torchvision.transforms.functional.center_crop(image, (target_height, target_width)) + return image + + + def get_height_width(self, image): + if self.dynamic_resolution: + width, height = image.size + if width * height > self.max_pixels: + scale = (width * height / self.max_pixels) ** 0.5 + height, width = int(height / scale), int(width / scale) + height = height // self.height_division_factor * self.height_division_factor + width = width // self.width_division_factor * self.width_division_factor + else: + height, width = self.height, self.width + return height, width + + + def parse_image(self, image): + image = self.crop_and_resize(image, *self.get_height_width(image)) + return image + + + def parse_video(self, video): + for i in range(len(video)): + video[i] = self.crop_and_resize(video[i], *self.get_height_width(video[i])) + return video + + + def load_data(self, item, data_type): + loader = get_loader(data_type) + if isinstance(item, list) or isinstance(item, tuple): + return [loader(p) for p in item] + else: + return loader(item) + + + def type_parser(self, item, data_type): + if data_type in ("raw", "int", "float"): + return item + elif data_type == "image": + if isinstance(item, list) or isinstance(item, tuple): + return [self.parse_image(img) for img in item] + else: + return self.parse_image(item) + elif data_type == "video": + return self.parse_video(item) + elif data_type == "tensor": + # TODO: implement tensor parsing + return item + else: 
+ return item + + + def __getitem__(self, data_id): + max_retries = 10 + while True: + data = self.data[data_id % len(self.data)].copy() + try: + for key in data.keys(): + data_type = self.keyconfigs.get(key, "raw") + item = self.load_data(data[key], data_type) + item = self.type_parser(item, data_type) + data[key] = item + return data + except: + warnings.warn(f"Error loading data with id {data_id}. Replacing with another data.") + data_id = torch.randint(0, len(self), (1,)).item() + max_retries -= 1 + if max_retries <= 0: + warnings.warn("Max retries reached. Returning None.") + return None + + + def __len__(self): + return len(self.data) * self.repeat + + +class ImageDataset(BaseDataset): + def __init__( + self, + base_path=None, metadata_path=None, + max_pixels=1920*1080, height=None, width=None, + height_division_factor=16, width_division_factor=16, + default_key_model="flux", input_configs="prompt:raw,image:image", + file_extensions=("jpg", "jpeg", "png", "webp"), + generated_target_key="image", repeat=1, + args=None, + ): + if args is not None: + base_path = args.dataset_base_path + metadata_path = args.dataset_metadata_path + height = args.height + width = args.width + max_pixels = args.max_pixels + input_configs = args.dataset_input_configs if args.dataset_input_configs else "prompt:raw,image:image" + repeat = args.dataset_repeat + + super().__init__( + base_path=base_path, + metadata_path=metadata_path, + max_pixels=max_pixels, + height=height, + width=width, + height_division_factor=height_division_factor, + width_division_factor=width_division_factor, + default_key_model=default_key_model, + input_configs=input_configs, + file_extensions=file_extensions, + generated_target_key=generated_target_key, + repeat=repeat, + ) + + def parse_image(self, image): + return self.crop_and_resize(image, *self.get_height_width(image)) + + +class VideoDataset(BaseDataset): + def __init__( + self, + base_path=None, metadata_path=None, + num_frames=81, + time_division_factor=4, time_division_remainder=1, + max_pixels=1920*1080, height=None, width=None, + height_division_factor=16, width_division_factor=16, + default_key_model="wan", input_configs="prompt:raw,video:video", + file_extensions=("jpg", "jpeg", "png", "webp", "mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"), + generated_target_key="video", repeat=1, + args=None, + ): + if args is not None: + base_path = args.dataset_base_path + metadata_path = args.dataset_metadata_path + height = args.height + width = args.width + max_pixels = args.max_pixels + input_configs = args.dataset_input_configs if args.dataset_input_configs else "prompt:raw,video:video" + repeat = args.dataset_repeat + num_frames = args.num_frames + + self.num_frames = num_frames + self.time_division_factor = time_division_factor + self.time_division_remainder = time_division_remainder + super().__init__( + base_path=base_path, + metadata_path=metadata_path, + max_pixels=max_pixels, + height=height, + width=width, + height_division_factor=height_division_factor, + width_division_factor=width_division_factor, + default_key_model=default_key_model, + input_configs=input_configs, + file_extensions=file_extensions, + generated_target_key=generated_target_key, + repeat=repeat, + ) + + + def parse_video(self, video): + num_frames = self.get_num_frames(video) + video = video[:num_frames] + for i in range(len(video)): + video[i] = self.crop_and_resize(video[i], *self.get_height_width(video[i])) + return video + + + def parse_image(self, image): + image = self.crop_and_resize(image, 
*self.get_height_width(image)) + return [image] + + + def get_num_frames(self, video): + num_frames = self.num_frames + if len(video) < num_frames: + num_frames = len(video) + while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder: + num_frames -= 1 + return num_frames + + +def general_dataset_parser(): + parser = argparse.ArgumentParser(add_help=False) + parser.add_argument("--dataset_base_path", type=str, default="", required=True, help="Base path of the dataset.") + parser.add_argument("--dataset_metadata_path", type=str, default=None, help="Path to the metadata file of the dataset.") + parser.add_argument("--max_pixels", type=int, default=1280*720, help="Maximum number of pixels per frame, used for dynamic resolution..") + parser.add_argument("--height", type=int, default=None, help="Height of images or videos. Leave `height` and `width` empty to enable dynamic resolution.") + parser.add_argument("--width", type=int, default=None, help="Width of images or videos. Leave `height` and `width` empty to enable dynamic resolution.") + parser.add_argument("--dataset_input_configs", type=str, default=None, help="Data file keys and data types in the metadata. Comma-separated.") + parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times to repeat the dataset per epoch.") + + return parser + + +def video_dataset_parser(): + parser = argparse.ArgumentParser(add_help=False) + parser.add_argument("--num_frames", type=int, default=81, help="Number of frames to sample from each video.") + return parser diff --git a/diffsynth/trainers/dataset_key_configs.py b/diffsynth/trainers/dataset_key_configs.py new file mode 100644 index 00000000..790c1615 --- /dev/null +++ b/diffsynth/trainers/dataset_key_configs.py @@ -0,0 +1,171 @@ +from easydict import EasyDict +from PIL import Image +import imageio +import torch + + +""" +types: +- raw: raw info which does not need any processing, example: "string" +- int: integer number to represent a value, example: 123 +- float: floating point number to represent a value, example: 123.45 +- image: image file, postfix: "jpg", "jpeg", "png", "webp" +- video: video file, postfix: "mp4", "avi", "mov", "wmv", "mkv", "flv", "webm" +- tensor: pytorch tensor, postfix: "pt" +""" + +# shared dataset key configuration for all models +shared_config = EasyDict(__name__='Datatype Config: Base') +shared_config.prompt = 'raw' + + +# dataset key configuration for Flux.1-dev model +flux_config = EasyDict(__name__='Datatype Config: Flux.1-dev') +flux_config.update(shared_config) +flux_config.image = 'image' + +flux_config.kontext_images = 'image' + +flux_config.ipadapter_images = 'image' + +flux_config.eligen_entity_prompts = 'raw' +flux_config.eligen_entity_masks = 'image' + +flux_config.infinityou_id_image = 'image' +flux_config.infinityou_guidance = 'float' + +flux_config.step1x_reference_image = 'image' + +flux_config.nexus_gen_reference_image = 'image' + +flux_config.value_controller_inputs = 'float' + +flux_config.input_latents = 'tensor' + +flux_config.controlnet_image = 'image' +flux_config.controlnet_inpaint_mask = 'image' +flux_config.controlnet_processor_id = 'raw' + + +# dataset key configuration for qwen-image +qwen_image_config = EasyDict(__name__='Datatype Config: Qwen-Image') +qwen_image_config.update(shared_config) +qwen_image_config.image = 'image' + + +# dataset key configuration for Wan model +wan_config = EasyDict(__name__='Datatype Config: Wan') +wan_config.update(shared_config) +wan_config.video = 
'video' + +wan_config.motion_bucket_id = 'int' + +wan_config.input_image = 'image' +wan_config.end_image = 'image' + +wan_config.control_video = 'video' + +wan_config.camera_control_direction = 'raw' +wan_config.camera_control_speed = 'float' + +wan_config.reference_image = 'image' + +wan_config.vace_video = 'video' +wan_config.vace_reference_image = 'image' + + +def get_default_config(model_name): + """ + Get the default dataset key configuration for the given model name. + :param model_name: Name of the model + :return: EasyDict containing the default dataset key configuration + """ + if model_name.lower() == 'flux': + return flux_config + elif model_name.lower() == 'qwen-image': + return qwen_image_config + elif model_name.lower() == 'wan': + return wan_config + else: + return shared_config + +def raw_loader(value): + """ + Load a raw value. + :param value: The raw value to load + :return: The loaded raw value + """ + return value + + +def int_loader(value): + """ + Load an integer value. + :param value: The integer value to load + :return: The loaded integer value + """ + return int(value) + + +def float_loader(value): + """ + Load a floating point value. + :param value: The floating point value to load + :return: The loaded floating point value + """ + return float(value) + + +def image_loader(file_path): + """ + Load an image file. + :param value: The image file path to load + :return: The loaded image + """ + return Image.open(file_path).convert('RGB') + + +def video_loader(file_path): + """ + Load a video file. + :param value: The video file path to load + :return: The loaded video + """ + reader = imageio.get_reader(file_path) + num_frames = int(reader.count_frames()) + frames = [] + for frame_id in range(num_frames): + frame = reader.get_data(frame_id) + frame = Image.fromarray(frame) + frames.append(frame) + reader.close() + return frames + +def tensor_loader(file_path): + """ + Load a PyTorch tensor file. + :param file_path: The tensor file path to load + :return: The loaded tensor + """ + return torch.load(file_path, map_location='cpu') + +def get_loader(data_type): + """ + Get the loader function for the given data type. 
+ :param data_type: The data type to get the loader for + :return: The loader function + """ + if data_type == 'raw': + return raw_loader + elif data_type == 'int': + return int_loader + elif data_type == 'float': + return float_loader + elif data_type == 'image': + return image_loader + elif data_type == 'video': + return video_loader + elif data_type == 'tensor': + return tensor_loader + else: + raise ValueError(f"Unsupported data type: {data_type}") \ No newline at end of file diff --git a/diffsynth/trainers/utils.py b/diffsynth/trainers/utils.py index c478e920..ac6c79a5 100644 --- a/diffsynth/trainers/utils.py +++ b/diffsynth/trainers/utils.py @@ -4,315 +4,7 @@ import pandas as pd from tqdm import tqdm from accelerate import Accelerator - - - -class ImageDataset(torch.utils.data.Dataset): - def __init__( - self, - base_path=None, metadata_path=None, - max_pixels=1920*1080, height=None, width=None, - height_division_factor=16, width_division_factor=16, - data_file_keys=("image",), - image_file_extension=("jpg", "jpeg", "png", "webp"), - repeat=1, - args=None, - ): - if args is not None: - base_path = args.dataset_base_path - metadata_path = args.dataset_metadata_path - height = args.height - width = args.width - max_pixels = args.max_pixels - data_file_keys = args.data_file_keys.split(",") - repeat = args.dataset_repeat - - self.base_path = base_path - self.max_pixels = max_pixels - self.height = height - self.width = width - self.height_division_factor = height_division_factor - self.width_division_factor = width_division_factor - self.data_file_keys = data_file_keys - self.image_file_extension = image_file_extension - self.repeat = repeat - - if height is not None and width is not None: - print("Height and width are fixed. Setting `dynamic_resolution` to False.") - self.dynamic_resolution = False - elif height is None and width is None: - print("Height and width are none. Setting `dynamic_resolution` to True.") - self.dynamic_resolution = True - - if metadata_path is None: - print("No metadata. Trying to generate it.") - metadata = self.generate_metadata(base_path) - print(f"{len(metadata)} lines in metadata.") - self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))] - elif metadata_path.endswith(".json"): - with open(metadata_path, "r") as f: - metadata = json.load(f) - self.data = metadata - elif metadata_path.endswith(".jsonl"): - metadata = [] - with open(metadata_path, 'r') as f: - for line in tqdm(f): - metadata.append(json.loads(line.strip())) - self.data = metadata - else: - metadata = pd.read_csv(metadata_path) - self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))] - - - def generate_metadata(self, folder): - image_list, prompt_list = [], [] - file_set = set(os.listdir(folder)) - for file_name in file_set: - if "." 
not in file_name: - continue - file_ext_name = file_name.split(".")[-1].lower() - file_base_name = file_name[:-len(file_ext_name)-1] - if file_ext_name not in self.image_file_extension: - continue - prompt_file_name = file_base_name + ".txt" - if prompt_file_name not in file_set: - continue - with open(os.path.join(folder, prompt_file_name), "r", encoding="utf-8") as f: - prompt = f.read().strip() - image_list.append(file_name) - prompt_list.append(prompt) - metadata = pd.DataFrame() - metadata["image"] = image_list - metadata["prompt"] = prompt_list - return metadata - - - def crop_and_resize(self, image, target_height, target_width): - width, height = image.size - scale = max(target_width / width, target_height / height) - image = torchvision.transforms.functional.resize( - image, - (round(height*scale), round(width*scale)), - interpolation=torchvision.transforms.InterpolationMode.BILINEAR - ) - image = torchvision.transforms.functional.center_crop(image, (target_height, target_width)) - return image - - - def get_height_width(self, image): - if self.dynamic_resolution: - width, height = image.size - if width * height > self.max_pixels: - scale = (width * height / self.max_pixels) ** 0.5 - height, width = int(height / scale), int(width / scale) - height = height // self.height_division_factor * self.height_division_factor - width = width // self.width_division_factor * self.width_division_factor - else: - height, width = self.height, self.width - return height, width - - - def load_image(self, file_path): - image = Image.open(file_path).convert("RGB") - image = self.crop_and_resize(image, *self.get_height_width(image)) - return image - - - def load_data(self, file_path): - return self.load_image(file_path) - - - def __getitem__(self, data_id): - data = self.data[data_id % len(self.data)].copy() - for key in self.data_file_keys: - if key in data: - if isinstance(data[key], list): - path = [os.path.join(self.base_path, p) for p in data[key]] - data[key] = [self.load_data(p) for p in path] - else: - path = os.path.join(self.base_path, data[key]) - data[key] = self.load_data(path) - if data[key] is None: - warnings.warn(f"cannot load file {data[key]}.") - return None - return data - - - def __len__(self): - return len(self.data) * self.repeat - - - -class VideoDataset(torch.utils.data.Dataset): - def __init__( - self, - base_path=None, metadata_path=None, - num_frames=81, - time_division_factor=4, time_division_remainder=1, - max_pixels=1920*1080, height=None, width=None, - height_division_factor=16, width_division_factor=16, - data_file_keys=("video",), - image_file_extension=("jpg", "jpeg", "png", "webp"), - video_file_extension=("mp4", "avi", "mov", "wmv", "mkv", "flv", "webm"), - repeat=1, - args=None, - ): - if args is not None: - base_path = args.dataset_base_path - metadata_path = args.dataset_metadata_path - height = args.height - width = args.width - max_pixels = args.max_pixels - num_frames = args.num_frames - data_file_keys = args.data_file_keys.split(",") - repeat = args.dataset_repeat - - self.base_path = base_path - self.num_frames = num_frames - self.time_division_factor = time_division_factor - self.time_division_remainder = time_division_remainder - self.max_pixels = max_pixels - self.height = height - self.width = width - self.height_division_factor = height_division_factor - self.width_division_factor = width_division_factor - self.data_file_keys = data_file_keys - self.image_file_extension = image_file_extension - self.video_file_extension = video_file_extension - 
self.repeat = repeat - - if height is not None and width is not None: - print("Height and width are fixed. Setting `dynamic_resolution` to False.") - self.dynamic_resolution = False - elif height is None and width is None: - print("Height and width are none. Setting `dynamic_resolution` to True.") - self.dynamic_resolution = True - - if metadata_path is None: - print("No metadata. Trying to generate it.") - metadata = self.generate_metadata(base_path) - print(f"{len(metadata)} lines in metadata.") - self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))] - elif metadata_path.endswith(".json"): - with open(metadata_path, "r") as f: - metadata = json.load(f) - self.data = metadata - else: - metadata = pd.read_csv(metadata_path) - self.data = [metadata.iloc[i].to_dict() for i in range(len(metadata))] - - - def generate_metadata(self, folder): - video_list, prompt_list = [], [] - file_set = set(os.listdir(folder)) - for file_name in file_set: - if "." not in file_name: - continue - file_ext_name = file_name.split(".")[-1].lower() - file_base_name = file_name[:-len(file_ext_name)-1] - if file_ext_name not in self.image_file_extension and file_ext_name not in self.video_file_extension: - continue - prompt_file_name = file_base_name + ".txt" - if prompt_file_name not in file_set: - continue - with open(os.path.join(folder, prompt_file_name), "r", encoding="utf-8") as f: - prompt = f.read().strip() - video_list.append(file_name) - prompt_list.append(prompt) - metadata = pd.DataFrame() - metadata["video"] = video_list - metadata["prompt"] = prompt_list - return metadata - - - def crop_and_resize(self, image, target_height, target_width): - width, height = image.size - scale = max(target_width / width, target_height / height) - image = torchvision.transforms.functional.resize( - image, - (round(height*scale), round(width*scale)), - interpolation=torchvision.transforms.InterpolationMode.BILINEAR - ) - image = torchvision.transforms.functional.center_crop(image, (target_height, target_width)) - return image - - - def get_height_width(self, image): - if self.dynamic_resolution: - width, height = image.size - if width * height > self.max_pixels: - scale = (width * height / self.max_pixels) ** 0.5 - height, width = int(height / scale), int(width / scale) - height = height // self.height_division_factor * self.height_division_factor - width = width // self.width_division_factor * self.width_division_factor - else: - height, width = self.height, self.width - return height, width - - - def get_num_frames(self, reader): - num_frames = self.num_frames - if int(reader.count_frames()) < num_frames: - num_frames = int(reader.count_frames()) - while num_frames > 1 and num_frames % self.time_division_factor != self.time_division_remainder: - num_frames -= 1 - return num_frames - - - def load_video(self, file_path): - reader = imageio.get_reader(file_path) - num_frames = self.get_num_frames(reader) - frames = [] - for frame_id in range(num_frames): - frame = reader.get_data(frame_id) - frame = Image.fromarray(frame) - frame = self.crop_and_resize(frame, *self.get_height_width(frame)) - frames.append(frame) - reader.close() - return frames - - - def load_image(self, file_path): - image = Image.open(file_path).convert("RGB") - image = self.crop_and_resize(image, *self.get_height_width(image)) - frames = [image] - return frames - - - def is_image(self, file_path): - file_ext_name = file_path.split(".")[-1] - return file_ext_name.lower() in self.image_file_extension - - - def is_video(self, file_path): - 
file_ext_name = file_path.split(".")[-1] - return file_ext_name.lower() in self.video_file_extension - - - def load_data(self, file_path): - if self.is_image(file_path): - return self.load_image(file_path) - elif self.is_video(file_path): - return self.load_video(file_path) - else: - return None - - - def __getitem__(self, data_id): - data = self.data[data_id % len(self.data)].copy() - for key in self.data_file_keys: - if key in data: - path = os.path.join(self.base_path, data[key]) - data[key] = self.load_data(path) - if data[key] is None: - warnings.warn(f"cannot load file {data[key]}.") - return None - return data - - - def __len__(self): - return len(self.data) * self.repeat - +from .dataset import general_dataset_parser, video_dataset_parser class DiffusionTrainingModule(torch.nn.Module): @@ -420,17 +112,14 @@ def launch_data_process_task(model: DiffusionTrainingModule, dataset, output_pat torch.save(inputs, os.path.join(output_path, "data_cache", f"{data_id}.pth")) +def flux_parser(): + dataset_parser = general_dataset_parser() -def wan_parser(): - parser = argparse.ArgumentParser(description="Simple example of a training script.") - parser.add_argument("--dataset_base_path", type=str, default="", required=True, help="Base path of the dataset.") - parser.add_argument("--dataset_metadata_path", type=str, default=None, help="Path to the metadata file of the dataset.") - parser.add_argument("--max_pixels", type=int, default=1280*720, help="Maximum number of pixels per frame, used for dynamic resolution..") - parser.add_argument("--height", type=int, default=None, help="Height of images or videos. Leave `height` and `width` empty to enable dynamic resolution.") - parser.add_argument("--width", type=int, default=None, help="Width of images or videos. Leave `height` and `width` empty to enable dynamic resolution.") - parser.add_argument("--num_frames", type=int, default=81, help="Number of frames per video. Frames are sampled from the video prefix.") - parser.add_argument("--data_file_keys", type=str, default="image,video", help="Data file keys in the metadata. Comma-separated.") - parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times to repeat the dataset per epoch.") + parser = argparse.ArgumentParser( + description="Simple example of a training script.", + parents=[dataset_parser], + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) parser.add_argument("--model_paths", type=str, default=None, help="Paths to load models. In JSON format.") parser.add_argument("--model_id_with_origin_paths", type=str, default=None, help="Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.") parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.") @@ -442,23 +131,21 @@ def wan_parser(): parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.") parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.") parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.") + parser.add_argument("--align_to_opensource_format", default=False, action="store_true", help="Whether to align the lora format to opensource format. 
Only for DiT's LoRA.") + parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.") parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.") parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.") - parser.add_argument("--max_timestep_boundary", type=float, default=1.0, help="Max timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).") - parser.add_argument("--min_timestep_boundary", type=float, default=0.0, help="Min timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).") return parser - - -def flux_parser(): - parser = argparse.ArgumentParser(description="Simple example of a training script.") - parser.add_argument("--dataset_base_path", type=str, default="", required=True, help="Base path of the dataset.") - parser.add_argument("--dataset_metadata_path", type=str, default=None, help="Path to the metadata file of the dataset.") - parser.add_argument("--max_pixels", type=int, default=1024*1024, help="Maximum number of pixels per frame, used for dynamic resolution..") - parser.add_argument("--height", type=int, default=None, help="Height of images. Leave `height` and `width` empty to enable dynamic resolution.") - parser.add_argument("--width", type=int, default=None, help="Width of images. Leave `height` and `width` empty to enable dynamic resolution.") - parser.add_argument("--data_file_keys", type=str, default="image", help="Data file keys in the metadata. Comma-separated.") - parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times to repeat the dataset per epoch.") +def wan_parser(): + dataset_parser = general_dataset_parser() + video_parser = video_dataset_parser() + + parser = argparse.ArgumentParser( + description="Simple example of a training script.", + parents=[dataset_parser, video_parser], + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) parser.add_argument("--model_paths", type=str, default=None, help="Paths to load models. In JSON format.") parser.add_argument("--model_id_with_origin_paths", type=str, default=None, help="Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. Comma-separated.") parser.add_argument("--learning_rate", type=float, default=1e-4, help="Learning rate.") @@ -470,23 +157,22 @@ def flux_parser(): parser.add_argument("--lora_target_modules", type=str, default="q,k,v,o,ffn.0,ffn.2", help="Which layers LoRA is added to.") parser.add_argument("--lora_rank", type=int, default=32, help="Rank of LoRA.") parser.add_argument("--extra_inputs", default=None, help="Additional model inputs, comma-separated.") - parser.add_argument("--align_to_opensource_format", default=False, action="store_true", help="Whether to align the lora format to opensource format. 
Only for DiT's LoRA.") - parser.add_argument("--use_gradient_checkpointing", default=False, action="store_true", help="Whether to use gradient checkpointing.") parser.add_argument("--use_gradient_checkpointing_offload", default=False, action="store_true", help="Whether to offload gradient checkpointing to CPU memory.") parser.add_argument("--gradient_accumulation_steps", type=int, default=1, help="Gradient accumulation steps.") + parser.add_argument("--max_timestep_boundary", type=float, default=1.0, help="Max timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).") + parser.add_argument("--min_timestep_boundary", type=float, default=0.0, help="Min timestep boundary (for mixed models, e.g., Wan-AI/Wan2.2-I2V-A14B).") return parser def qwen_image_parser(): - parser = argparse.ArgumentParser(description="Simple example of a training script.") - parser.add_argument("--dataset_base_path", type=str, default="", required=True, help="Base path of the dataset.") - parser.add_argument("--dataset_metadata_path", type=str, default=None, help="Path to the metadata file of the dataset.") - parser.add_argument("--max_pixels", type=int, default=1024*1024, help="Maximum number of pixels per frame, used for dynamic resolution..") - parser.add_argument("--height", type=int, default=None, help="Height of images. Leave `height` and `width` empty to enable dynamic resolution.") - parser.add_argument("--width", type=int, default=None, help="Width of images. Leave `height` and `width` empty to enable dynamic resolution.") - parser.add_argument("--data_file_keys", type=str, default="image", help="Data file keys in the metadata. Comma-separated.") - parser.add_argument("--dataset_repeat", type=int, default=1, help="Number of times to repeat the dataset per epoch.") + dataset_parser = general_dataset_parser() + + parser = argparse.ArgumentParser( + description="Simple example of a training script.", + parents=[dataset_parser], + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + ) parser.add_argument("--model_paths", type=str, default=None, help="Paths to load models. In JSON format.") parser.add_argument("--model_id_with_origin_paths", type=str, default=None, help="Model ID with origin paths, e.g., Wan-AI/Wan2.1-T2V-1.3B:diffusion_pytorch_model*.safetensors. 
Comma-separated.") parser.add_argument("--tokenizer_path", type=str, default=None, help="Paths to tokenizer.") diff --git a/examples/flux/model_training/full/FLUX.1-dev.sh b/examples/flux/model_training/full/FLUX.1-dev.sh index 92549571..9edae008 100644 --- a/examples/flux/model_training/full/FLUX.1-dev.sh +++ b/examples/flux/model_training/full/FLUX.1-dev.sh @@ -1,6 +1,7 @@ accelerate launch --config_file examples/flux/model_training/full/accelerate_config.yaml examples/flux/model_training/train.py \ --dataset_base_path data/example_image_dataset \ --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --dataset_input_configs "prompt:raw,image:image" \ --max_pixels 1048576 \ --dataset_repeat 400 \ --model_id_with_origin_paths "black-forest-labs/FLUX.1-dev:flux1-dev.safetensors,black-forest-labs/FLUX.1-dev:text_encoder/model.safetensors,black-forest-labs/FLUX.1-dev:text_encoder_2/,black-forest-labs/FLUX.1-dev:ae.safetensors" \ diff --git a/examples/flux/model_training/train.py b/examples/flux/model_training/train.py index ca52ff49..f77866b7 100644 --- a/examples/flux/model_training/train.py +++ b/examples/flux/model_training/train.py @@ -1,6 +1,7 @@ import torch, os, json from diffsynth.pipelines.flux_image_new import FluxImagePipeline, ModelConfig, ControlNetInput -from diffsynth.trainers.utils import DiffusionTrainingModule, ImageDataset, ModelLogger, launch_training_task, flux_parser +from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, launch_training_task, flux_parser +from diffsynth.trainers.dataset import ImageDataset from diffsynth.models.lora import FluxLoRAConverter os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -98,7 +99,7 @@ def forward(self, data, inputs=None): if __name__ == "__main__": parser = flux_parser() args = parser.parse_args() - dataset = ImageDataset(args=args) + dataset = ImageDataset(args=args, default_key_model="flux") model = FluxTrainingModule( model_paths=args.model_paths, model_id_with_origin_paths=args.model_id_with_origin_paths, diff --git a/examples/qwen_image/model_training/lora/Qwen-Image.sh b/examples/qwen_image/model_training/lora/Qwen-Image.sh index 0c943918..b5968acd 100644 --- a/examples/qwen_image/model_training/lora/Qwen-Image.sh +++ b/examples/qwen_image/model_training/lora/Qwen-Image.sh @@ -1,6 +1,7 @@ accelerate launch examples/qwen_image/model_training/train.py \ --dataset_base_path data/example_image_dataset \ --dataset_metadata_path data/example_image_dataset/metadata.csv \ + --dataset_input_configs "prompt:raw,image:image" \ --max_pixels 1048576 \ --dataset_repeat 50 \ --model_id_with_origin_paths "Qwen/Qwen-Image:transformer/diffusion_pytorch_model*.safetensors,Qwen/Qwen-Image:text_encoder/model*.safetensors,Qwen/Qwen-Image:vae/diffusion_pytorch_model.safetensors" \ diff --git a/examples/qwen_image/model_training/train.py b/examples/qwen_image/model_training/train.py index 48d2d1a5..feca4ca5 100644 --- a/examples/qwen_image/model_training/train.py +++ b/examples/qwen_image/model_training/train.py @@ -1,6 +1,7 @@ import torch, os, json from diffsynth.pipelines.qwen_image import QwenImagePipeline, ModelConfig -from diffsynth.trainers.utils import DiffusionTrainingModule, ImageDataset, ModelLogger, launch_training_task, qwen_image_parser +from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, launch_training_task, qwen_image_parser +from diffsynth.trainers.dataset import ImageDataset os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -92,7 +93,7 @@ def forward(self, data, inputs=None): 
if __name__ == "__main__": parser = qwen_image_parser() args = parser.parse_args() - dataset = ImageDataset(args=args) + dataset = ImageDataset(args=args, default_key_model="qwen-image") model = QwenImageTrainingModule( model_paths=args.model_paths, model_id_with_origin_paths=args.model_id_with_origin_paths, diff --git a/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh b/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh index d16a2871..4c7d3a3f 100644 --- a/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh +++ b/examples/wanvideo/model_training/lora/Wan2.1-T2V-1.3B.sh @@ -1,6 +1,7 @@ accelerate launch examples/wanvideo/model_training/train.py \ --dataset_base_path data/example_video_dataset \ --dataset_metadata_path data/example_video_dataset/metadata.csv \ + --dataset_input_configs "prompt,video" \ --height 480 \ --width 832 \ --dataset_repeat 100 \ diff --git a/examples/wanvideo/model_training/train.py b/examples/wanvideo/model_training/train.py index 98c737fb..bc70b596 100644 --- a/examples/wanvideo/model_training/train.py +++ b/examples/wanvideo/model_training/train.py @@ -1,6 +1,7 @@ import torch, os, json from diffsynth.pipelines.wan_video_new import WanVideoPipeline, ModelConfig -from diffsynth.trainers.utils import DiffusionTrainingModule, VideoDataset, ModelLogger, launch_training_task, wan_parser +from diffsynth.trainers.utils import DiffusionTrainingModule, ModelLogger, launch_training_task, wan_parser +from diffsynth.trainers.dataset import VideoDataset os.environ["TOKENIZERS_PARALLELISM"] = "false" @@ -104,7 +105,7 @@ def forward(self, data, inputs=None): if __name__ == "__main__": parser = wan_parser() args = parser.parse_args() - dataset = VideoDataset(args=args) + dataset = VideoDataset(args=args, default_key_model="wan") model = WanTrainingModule( model_paths=args.model_paths, model_id_with_origin_paths=args.model_id_with_origin_paths,
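Usage note (not part of the patch): the training scripts above only exercise the new dataset classes through argparse. Below is a minimal, hedged sketch of constructing them directly; the data/example_* paths and metadata.csv files are the ones assumed by the example shell scripts, and the key/type strings follow the defaults in dataset_key_configs.py, so adjust them for your own data.

from diffsynth.trainers.dataset import ImageDataset, VideoDataset

# Keys given as "key:type" are parsed explicitly; keys given without a type
# (e.g. "prompt,video") fall back to the default key configuration of the
# chosen model family ("flux", "qwen-image", or "wan").
image_dataset = ImageDataset(
    base_path="data/example_image_dataset",                   # assumed example layout
    metadata_path="data/example_image_dataset/metadata.csv",
    max_pixels=1024 * 1024,                                    # dynamic-resolution cap, as in the FLUX script
    input_configs="prompt:raw,image:image",
    default_key_model="flux",
)
sample = image_dataset[0]   # {"prompt": str, "image": cropped/resized PIL.Image}

video_dataset = VideoDataset(
    base_path="data/example_video_dataset",                   # assumed example layout
    metadata_path="data/example_video_dataset/metadata.csv",
    height=480, width=832,                                     # fixed resolution, as in the Wan script
    num_frames=81,
    input_configs="prompt,video",                              # types resolved from the Wan defaults
    default_key_model="wan",
)
clip = video_dataset[0]     # {"prompt": str, "video": list of PIL.Image frames}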