- 
                Notifications
    You must be signed in to change notification settings 
- Fork 6.5k
fix required instance_column and image_column cast #10175
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
2c1cbb3
              c9fedf1
              ab6961e
              4a7e3f9
              b9e1ccf
              cd6bea5
              File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | 
|---|---|---|
|  | @@ -308,7 +308,7 @@ def parse_args(input_args=None): | |
| "--instance_prompt", | ||
| type=str, | ||
| default=None, | ||
| required=True, | ||
| required=False, | ||
| help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'", | ||
| ) | ||
| parser.add_argument( | ||
|  | @@ -714,7 +714,7 @@ def __init__( | |
| # we load the training data using load_dataset | ||
| if args.dataset_name is not None: | ||
| try: | ||
| from datasets import load_dataset | ||
| from datasets import load_dataset, Image | ||
| except ImportError: | ||
| raise ImportError( | ||
| "You are trying to load your data using the datasets library. If you wish to train using custom " | ||
|  | @@ -742,6 +742,8 @@ def __init__( | |
| raise ValueError( | ||
| f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" | ||
| ) | ||
|  | ||
| dataset["train"] = dataset["train"].cast_column(image_column, Image(decode=True)) | ||
| There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is this needed? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @sayakpaul fix to force casting and image column to decode to ensure this is done for feature sets where decoding automatically is not the default There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Okay let's have that as a comment too. Thanks for explaining!         
                  davidberenstein1957 marked this conversation as resolved.
              Show resolved
            Hide resolved | ||
| instance_images = dataset["train"][image_column] | ||
|  | ||
| if args.caption_column is None: | ||
|  | @@ -768,7 +770,7 @@ def __init__( | |
|  | ||
| instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())] | ||
| self.custom_instance_prompts = None | ||
|         
                  davidberenstein1957 marked this conversation as resolved.
              Show resolved
            Hide resolved | ||
|  | ||
|         
                  davidberenstein1957 marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||
| self.instance_images = [] | ||
| for img in instance_images: | ||
| self.instance_images.extend(itertools.repeat(img, repeats)) | ||
|  | @@ -784,6 +786,7 @@ def __init__( | |
| ] | ||
| ) | ||
| for image in self.instance_images: | ||
|  | ||
|         
                  davidberenstein1957 marked this conversation as resolved.
              Outdated
          
            Show resolved
            Hide resolved | ||
| image = exif_transpose(image) | ||
| if not image.mode == "RGB": | ||
| image = image.convert("RGB") | ||
|  | ||
Uh oh!
There was an error while loading. Please reload this page.