@@ -971,7 +971,6 @@ class DreamBoothDataset(Dataset):
971971
972972    def  __init__ (
973973        self ,
974-         args ,
975974        instance_data_root ,
976975        instance_prompt ,
977976        class_prompt ,
@@ -981,8 +980,10 @@ def __init__(
981980        class_num = None ,
982981        size = 1024 ,
983982        repeats = 1 ,
983+         center_crop = False ,
984984    ):
985985        self .size  =  size 
986+         self .center_crop  =  center_crop 
986987
987988        self .instance_prompt  =  instance_prompt 
988989        self .custom_instance_prompts  =  None 
@@ -1074,11 +1075,11 @@ def __init__(
10741075                # flip 
10751076                image  =  train_flip (image )
10761077            if  args .center_crop :
1077-                 y1  =  max (0 , int (round ((image .height  -  self . size ) /  2.0 )))
1078-                 x1  =  max (0 , int (round ((image .width  -  self . size ) /  2.0 )))
1078+                 y1  =  max (0 , int (round ((image .height  -  args . resolution ) /  2.0 )))
1079+                 x1  =  max (0 , int (round ((image .width  -  args . resolution ) /  2.0 )))
10791080                image  =  train_crop (image )
10801081            else :
1081-                 y1 , x1 , h , w  =  train_crop .get_params (image , (self . size ,  self . size ))
1082+                 y1 , x1 , h , w  =  train_crop .get_params (image , (args . resolution ,  args . resolution ))
10821083                image  =  crop (image , y1 , x1 , h , w )
10831084            image  =  train_transforms (image )
10841085            self .pixel_values .append (image )
@@ -1826,7 +1827,6 @@ def load_model_hook(models, input_dir):
18261827
18271828    # Dataset and DataLoaders creation: 
18281829    train_dataset  =  DreamBoothDataset (
1829-         args = args ,
18301830        instance_data_root = args .instance_data_dir ,
18311831        instance_prompt = args .instance_prompt ,
18321832        train_text_encoder_ti = args .train_text_encoder_ti ,
@@ -1836,6 +1836,7 @@ def load_model_hook(models, input_dir):
18361836        class_num = args .num_class_images ,
18371837        size = args .resolution ,
18381838        repeats = args .repeats ,
1839+         center_crop = args .center_crop ,
18391840    )
18401841
18411842    train_dataloader  =  torch .utils .data .DataLoader (
0 commit comments