refactor training dataset #740


Open · wants to merge 1 commit into main
Conversation

@mi804 (Collaborator) commented on Aug 5, 2025

No description provided.

@gemini-code-assist bot (Contributor) left a comment

Summary of Changes

Hello @mi804, I'm Gemini Code Assist[1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

I've refactored the training dataset handling to improve modularity and maintainability. This involved creating dedicated files for dataset classes and data type configurations, centralizing common logic, and streamlining argument parsing. The changes aim to make it easier to add new dataset types and integrate them with different training pipelines, while also cleaning up existing utility files and example scripts.

Highlights

  • Centralized Dataset Management: I've introduced new BaseDataset, ImageDataset, and VideoDataset classes in diffsynth/trainers/dataset.py. These classes provide a unified and extensible framework for handling various types of training data, including dynamic resolution adjustments and metadata processing.
  • Modular Data Type Configuration and Loaders: A new file, diffsynth/trainers/dataset_key_configs.py, has been added to define data type configurations and corresponding loader functions. This modular approach allows for easy extension and management of how different data keys (e.g., 'image', 'video', 'prompt') are processed for various models.
  • Refactored Utility Functions and Argument Parsers: The existing ImageDataset and VideoDataset implementations were removed from diffsynth/trainers/utils.py to eliminate redundancy. The argument parsers (flux_parser, wan_parser, qwen_image_parser) in utils.py now leverage the new general_dataset_parser and video_dataset_parser from dataset.py, streamlining argument definition.
  • Updated Training Examples for New Dataset Structure: All relevant example training scripts and their corresponding Python files have been updated to use the new dataset structure. This includes adjusting import paths for the dataset classes and passing default_key_model and dataset_input_configs to ensure compatibility with the refactored dataset logic (a usage sketch follows this list).
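
To make the new pieces concrete, here is a minimal usage sketch assembled from the highlights above. It is not the PR's verbatim API: the constructor arguments, string values, and paths below are assumptions.

# Hypothetical usage of the refactored dataset classes (signatures assumed, not taken from the diff).
from diffsynth.trainers.dataset import ImageDataset

dataset = ImageDataset(
    base_path="data/train",                    # assumed folder of images and metadata
    metadata_path="data/train/metadata.csv",   # None would trigger metadata auto-generation
    default_key_model="qwen-image",            # assumed: selects the per-model key configuration
    dataset_input_configs={"image": "image", "prompt": "raw"},  # maps data keys to loader types
)
sample = dataset[0]  # expected: a dict such as {"image": <PIL image>, "prompt": "a caption"}
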
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e., a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature              | Command             | Description
Code Review          | /gemini review      | Performs a code review for the current pull request in its current state.
Pull Request Summary | /gemini summary     | Provides a summary of the current pull request in its current state.
Comment              | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments.
Help                 | /gemini help        | Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder at the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments or fill out our survey to provide feedback.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  [1] Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double-check it and use code with caution.

@gemini-code-assist bot (Contributor) left a comment

Code Review

This pull request introduces a significant and beneficial refactoring of the dataset handling logic. By introducing a BaseDataset and reimplementing ImageDataset and VideoDataset as its subclasses, the code is now more modular and reusable. The introduction of dataset_key_configs provides a flexible way to manage data loading for different models. The refactoring of argparse functions in utils.py also effectively reduces code duplication.

My review includes a few suggestions to fix a critical bug in metadata generation and to improve error handling and code structure for better maintainability and debugging.

Comment on lines +63 to +66
def load_meta(self, metadata_path, base_path=None):
    if metadata_path is None:
        print("No metadata. Trying to generate it.")
        metadata = self.generate_metadata(base_path)

critical

There's a critical issue here when metadata_path is None. The base_path parameter of load_meta is never supplied by callers and defaults to None, and that None is then forwarded to generate_metadata. Depending on how generate_metadata uses it, this either raises an error (e.g. os.path.join(None, ...) is a TypeError) or silently scans the wrong directory, since os.listdir(None) falls back to the current working directory rather than self.base_path.

To fix this, you should remove the base_path parameter from load_meta and use self.base_path when calling generate_metadata.

Suggested change

-def load_meta(self, metadata_path, base_path=None):
-    if metadata_path is None:
-        print("No metadata. Trying to generate it.")
-        metadata = self.generate_metadata(base_path)
+def load_meta(self, metadata_path):
+    if metadata_path is None:
+        print("No metadata. Trying to generate it.")
+        metadata = self.generate_metadata(self.base_path)

Comment on lines +213 to +214
except:
    warnings.warn(f"Error loading data with id {data_id}. Replacing with another data.")

high

Using a bare except here swallows the original exception, which makes debugging data loading issues very challenging. You should catch the specific Exception and include its message in the warning to provide more context on what went wrong.

Suggested change

-except:
-    warnings.warn(f"Error loading data with id {data_id}. Replacing with another data.")
+except Exception as e:
+    warnings.warn(f"Error loading data with id {data_id}: {e}. Replacing with another data.")

Comment on lines +135 to +136
except:
    continue

medium

Using a bare except can hide unexpected errors and make debugging difficult. The most likely exception in this block is a KeyError raised when a required key is missing from the data, so it's better to catch that specific exception and let anything unexpected propagate.

Suggested change

-except:
-    continue
+except KeyError:
+    continue

Comment on lines +152 to +171
def get_loader(data_type):
    """
    Get the loader function for the given data type.
    :param data_type: The data type to get the loader for
    :return: The loader function
    """
    if data_type == 'raw':
        return raw_loader
    elif data_type == 'int':
        return int_loader
    elif data_type == 'float':
        return float_loader
    elif data_type == 'image':
        return image_loader
    elif data_type == 'video':
        return video_loader
    elif data_type == 'tensor':
        return tensor_loader
    else:
        raise ValueError(f"Unsupported data type: {data_type}")

medium

This if/elif chain for selecting a loader function can be refactored to use a dictionary mapping. This improves readability and maintainability, and it is slightly faster since a dictionary lookup is O(1).

You could define a LOADER_MAP constant at the module level and then simplify this function.

For example:

LOADER_MAP = {
    'raw': raw_loader,
    'int': int_loader,
    'float': float_loader,
    'image': image_loader,
    'video': video_loader,
    'tensor': tensor_loader,
}

def get_loader(data_type):
    """
    Get the loader function for the given data type.
    :param data_type: The data type to get the loader for
    :return: The loader function
    """
    loader = LOADER_MAP.get(data_type)
    if loader is None:
        raise ValueError(f"Unsupported data type: {data_type}")
    return loader
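
As a quick usage check (the file path below is made up, and the call assumes image_loader accepts a single path argument):

loader = get_loader('image')          # O(1) dictionary lookup instead of an if/elif chain
image = loader('data/train/000.png')  # hypothetical path; image_loader's signature is an assumption

get_loader('audio')                   # raises ValueError: Unsupported data type: audio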
