@@ -971,6 +971,7 @@ class DreamBoothDataset(Dataset):
971971
972972    def  __init__ (
973973        self ,
974+         args ,
974975        instance_data_root ,
975976        instance_prompt ,
976977        class_prompt ,
@@ -980,10 +981,8 @@ def __init__(
980981        class_num = None ,
981982        size = 1024 ,
982983        repeats = 1 ,
983-         center_crop = False ,
984984    ):
985985        self .size  =  size 
986-         self .center_crop  =  center_crop 
987986
988987        self .instance_prompt  =  instance_prompt 
989988        self .custom_instance_prompts  =  None 
@@ -1058,7 +1057,7 @@ def __init__(
10581057        if  interpolation  is  None :
10591058            raise  ValueError (f"Unsupported interpolation mode { interpolation = }  )
10601059        train_resize  =  transforms .Resize (size , interpolation = interpolation )
1061-         train_crop  =  transforms .CenterCrop (size ) if  center_crop  else  transforms .RandomCrop (size )
1060+         train_crop  =  transforms .CenterCrop (size ) if  args . center_crop  else  transforms .RandomCrop (size )
10621061        train_flip  =  transforms .RandomHorizontalFlip (p = 1.0 )
10631062        train_transforms  =  transforms .Compose (
10641063            [
@@ -1075,11 +1074,11 @@ def __init__(
10751074                # flip 
10761075                image  =  train_flip (image )
10771076            if  args .center_crop :
1078-                 y1  =  max (0 , int (round ((image .height  -  args . resolution ) /  2.0 )))
1079-                 x1  =  max (0 , int (round ((image .width  -  args . resolution ) /  2.0 )))
1077+                 y1  =  max (0 , int (round ((image .height  -  self . size ) /  2.0 )))
1078+                 x1  =  max (0 , int (round ((image .width  -  self . size ) /  2.0 )))
10801079                image  =  train_crop (image )
10811080            else :
1082-                 y1 , x1 , h , w  =  train_crop .get_params (image , (args . resolution ,  args . resolution ))
1081+                 y1 , x1 , h , w  =  train_crop .get_params (image , (self . size ,  self . size ))
10831082                image  =  crop (image , y1 , x1 , h , w )
10841083            image  =  train_transforms (image )
10851084            self .pixel_values .append (image )
@@ -1102,7 +1101,7 @@ def __init__(
11021101        self .image_transforms  =  transforms .Compose (
11031102            [
11041103                transforms .Resize (size , interpolation = interpolation ),
1105-                 transforms .CenterCrop (size ) if  center_crop  else  transforms .RandomCrop (size ),
1104+                 transforms .CenterCrop (size ) if  args . center_crop  else  transforms .RandomCrop (size ),
11061105                transforms .ToTensor (),
11071106                transforms .Normalize ([0.5 ], [0.5 ]),
11081107            ]
@@ -1827,6 +1826,7 @@ def load_model_hook(models, input_dir):
18271826
18281827    # Dataset and DataLoaders creation: 
18291828    train_dataset  =  DreamBoothDataset (
1829+         args = args ,
18301830        instance_data_root = args .instance_data_dir ,
18311831        instance_prompt = args .instance_prompt ,
18321832        train_text_encoder_ti = args .train_text_encoder_ti ,
@@ -1836,7 +1836,6 @@ def load_model_hook(models, input_dir):
18361836        class_num = args .num_class_images ,
18371837        size = args .resolution ,
18381838        repeats = args .repeats ,
1839-         center_crop = args .center_crop ,
18401839    )
18411840
18421841    train_dataloader  =  torch .utils .data .DataLoader (
0 commit comments