refactor training dataset #740
base: main
Conversation
Summary of Changes
Hello @mi804, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
I've refactored the training dataset handling to improve modularity and maintainability. This involved creating dedicated files for dataset classes and data type configurations, centralizing common logic, and streamlining argument parsing. The changes aim to make it easier to add new dataset types and integrate them with different training pipelines, while also cleaning up existing utility files and example scripts.
Highlights
- **Centralized Dataset Management:** I've introduced new `BaseDataset`, `ImageDataset`, and `VideoDataset` classes in `diffsynth/trainers/dataset.py`. These classes provide a unified and extensible framework for handling various types of training data, including dynamic resolution adjustments and metadata processing (see the usage sketch after this list).
- **Modular Data Type Configuration and Loaders:** A new file, `diffsynth/trainers/dataset_key_configs.py`, has been added to define data type configurations and corresponding loader functions. This modular approach allows for easy extension and management of how different data keys (e.g., 'image', 'video', 'prompt') are processed for various models.
- **Refactored Utility Functions and Argument Parsers:** The existing `ImageDataset` and `VideoDataset` implementations were removed from `diffsynth/trainers/utils.py` to eliminate redundancy. The argument parsers (`flux_parser`, `wan_parser`, `qwen_image_parser`) in `utils.py` now leverage the new `general_dataset_parser` and `video_dataset_parser` from `dataset.py`, streamlining argument definition.
- **Updated Training Examples for New Dataset Structure:** All relevant example training scripts and their corresponding Python files have been updated to utilize the new dataset structure. This includes adjusting import paths for dataset classes and passing `default_key_model` and `dataset_input_configs` to ensure compatibility and proper data handling with the refactored dataset logic.
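As a rough illustration of how the refactored pieces fit together, here is a hypothetical usage sketch. The constructor arguments below are assumptions pieced together from this summary (`default_key_model` and `dataset_input_configs` are mentioned above, but their exact placement is not confirmed here), so treat this as a sketch rather than the PR's actual API:

```python
# Hypothetical sketch; argument names and shapes are assumptions based on the summary above.
from diffsynth.trainers.dataset import ImageDataset

dataset = ImageDataset(
    base_path="data/train",                    # dataset root; metadata can be generated from it
    metadata_path="data/train/metadata.csv",   # optional; generated from base_path if missing
    default_key_model="qwen_image",            # selects the per-model key configs (assumed name)
    dataset_input_configs=None,                # None -> fall back to the model's defaults (assumed)
)
sample = dataset[0]  # e.g. {"image": <PIL.Image.Image>, "prompt": "..."}
```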
Code Review
This pull request introduces a significant and beneficial refactoring of the dataset handling logic. By creating a `BaseDataset` and moving `ImageDataset` and `VideoDataset` to be subclasses, the code is now more modular and reusable. The introduction of `dataset_key_configs` provides a flexible way to manage data loading for different models. The refactoring of `argparse` functions in `utils.py` also effectively reduces code duplication.
My review includes a few suggestions to fix a critical bug in metadata generation and to improve error handling and code structure for better maintainability and debugging.
```python
def load_meta(self, metadata_path, base_path=None):
    if metadata_path is None:
        print("No metadata. Trying to generate it.")
        metadata = self.generate_metadata(base_path)
```
There's a critical issue here when `metadata_path` is `None`. The `base_path` parameter of `load_meta` is unused by callers and defaults to `None`, and that `None` is then passed straight into `generate_metadata`. Depending on how the path is used there, this either raises a `TypeError` (e.g., from `os.path.join(None, ...)`) or, worse, silently scans the wrong directory (in Python 3, `os.listdir(None)` lists the current working directory).
To fix this, you should remove the `base_path` parameter from `load_meta` and use `self.base_path` when calling `generate_metadata`.
```diff
-def load_meta(self, metadata_path, base_path=None):
-    if metadata_path is None:
-        print("No metadata. Trying to generate it.")
-        metadata = self.generate_metadata(base_path)
+def load_meta(self, metadata_path):
+    if metadata_path is None:
+        print("No metadata. Trying to generate it.")
+        metadata = self.generate_metadata(self.base_path)
```
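Building on that fix, a defensive check inside `generate_metadata` would also turn any future `None` path into an immediate, readable error. A minimal sketch, assuming `generate_metadata` scans `base_path` with `os.listdir` (the PR's actual implementation may differ):

```python
import os

def generate_metadata(self, base_path):
    # Fail fast with a clear message instead of scanning an unintended directory.
    if base_path is None:
        raise ValueError(
            "Cannot generate metadata: base_path is None. "
            "Pass metadata_path explicitly or construct the dataset with a valid base_path."
        )
    file_names = os.listdir(base_path)
    ...
```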
```python
except:
    warnings.warn(f"Error loading data with id {data_id}. Replacing with another data.")
```
Using a bare `except` here swallows the original exception, which makes debugging data loading issues very challenging. You should catch the specific `Exception` and include its message in the warning to provide more context on what went wrong.
```diff
-except:
-    warnings.warn(f"Error loading data with id {data_id}. Replacing with another data.")
+except Exception as e:
+    warnings.warn(f"Error loading data with id {data_id}: {e}. Replacing with another data.")
```
```python
except:
    continue
```
The same concern applies here: this bare `except` silently discards the error, so skipped entries are undiagnosable. Catching `Exception` and logging a message would make failures visible.
```python
def get_loader(data_type):
    """
    Get the loader function for the given data type.
    :param data_type: The data type to get the loader for
    :return: The loader function
    """
    if data_type == 'raw':
        return raw_loader
    elif data_type == 'int':
        return int_loader
    elif data_type == 'float':
        return float_loader
    elif data_type == 'image':
        return image_loader
    elif data_type == 'video':
        return video_loader
    elif data_type == 'tensor':
        return tensor_loader
    else:
        raise ValueError(f"Unsupported data type: {data_type}")
```
This `if`/`elif` chain for selecting a loader function can be refactored to use a dictionary mapping. This improves readability and maintainability, and is slightly more performant as it's an O(1) lookup.
You could define a `LOADER_MAP` constant at the module level and then simplify this function. For example:
```python
LOADER_MAP = {
    'raw': raw_loader,
    'int': int_loader,
    'float': float_loader,
    'image': image_loader,
    'video': video_loader,
    'tensor': tensor_loader,
}

def get_loader(data_type):
    """
    Get the loader function for the given data type.
    :param data_type: The data type to get the loader for
    :return: The loader function
    """
    loader = LOADER_MAP.get(data_type)
    if loader is None:
        raise ValueError(f"Unsupported data type: {data_type}")
    return loader
```
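Call sites would be unchanged by this refactor; for instance (`"image"` is a key in the map above, while `"audio"` is a made-up unsupported type used only to show the error path):

```python
loader = get_loader("image")    # O(1) dictionary lookup instead of walking an if/elif chain
# loader("path/to/sample.png")  # hypothetical call; actual loader signatures are defined elsewhere

get_loader("audio")  # raises ValueError: Unsupported data type: audio
```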