@@ -308,7 +308,7 @@ def parse_args(input_args=None):
308308        "--instance_prompt" ,
309309        type = str ,
310310        default = None ,
311-         required = True ,
311+         required = False ,
312312        help = "The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'" ,
313313    )
314314    parser .add_argument (
@@ -714,7 +714,7 @@ def __init__(
714714        # we load the training data using load_dataset 
715715        if  args .dataset_name  is  not None :
716716            try :
717-                 from  datasets  import  load_dataset 
717+                 from  datasets  import  load_dataset ,  Image 
718718            except  ImportError :
719719                raise  ImportError (
720720                    "You are trying to load your data using the datasets library. If you wish to train using custom " 
@@ -742,6 +742,8 @@ def __init__(
742742                    raise  ValueError (
743743                        f"`--image_column` value '{ args .image_column } { ', ' .join (column_names )}  
744744                    )
745+                     
746+             dataset ["train" ] =  dataset ["train" ].cast_column (image_column , Image (decode = True ))
745747            instance_images  =  dataset ["train" ][image_column ]
746748
747749            if  args .caption_column  is  None :
@@ -768,7 +770,7 @@ def __init__(
768770
769771            instance_images  =  [Image .open (path ) for  path  in  list (Path (instance_data_root ).iterdir ())]
770772            self .custom_instance_prompts  =  None 
771- 
773+          
772774        self .instance_images  =  []
773775        for  img  in  instance_images :
774776            self .instance_images .extend (itertools .repeat (img , repeats ))
@@ -784,6 +786,7 @@ def __init__(
784786            ]
785787        )
786788        for  image  in  self .instance_images :
789+             
787790            image  =  exif_transpose (image )
788791            if  not  image .mode  ==  "RGB" :
789792                image  =  image .convert ("RGB" )
0 commit comments