@@ -934,6 +934,7 @@ class DreamBoothDataset(Dataset):
934934
935935 def __init__ (
936936 self ,
937+ args ,
937938 instance_data_root ,
938939 instance_prompt ,
939940 class_prompt ,
@@ -943,10 +944,8 @@ def __init__(
943944 class_num = None ,
944945 size = 1024 ,
945946 repeats = 1 ,
946- center_crop = False ,
947947 ):
948948 self .size = size
949- self .center_crop = center_crop
950949
951950 self .instance_prompt = instance_prompt
952951 self .custom_instance_prompts = None
@@ -1035,11 +1034,11 @@ def __init__(
10351034 # flip
10361035 image = train_flip (image )
10371036 if args .center_crop :
1038- y1 = max (0 , int (round ((image .height - args . resolution ) / 2.0 )))
1039- x1 = max (0 , int (round ((image .width - args . resolution ) / 2.0 )))
1037+ y1 = max (0 , int (round ((image .height - self . size ) / 2.0 )))
1038+ x1 = max (0 , int (round ((image .width - self . size ) / 2.0 )))
10401039 image = train_crop (image )
10411040 else :
1042- y1 , x1 , h , w = train_crop .get_params (image , (args . resolution , args . resolution ))
1041+ y1 , x1 , h , w = train_crop .get_params (image , (self . size , self . size ))
10431042 image = crop (image , y1 , x1 , h , w )
10441043 image = train_transforms (image )
10451044 self .pixel_values .append (image )
@@ -1875,6 +1874,7 @@ def load_model_hook(models, input_dir):
18751874
18761875 # Dataset and DataLoaders creation:
18771876 train_dataset = DreamBoothDataset (
1877+ args = args ,
18781878 instance_data_root = args .instance_data_dir ,
18791879 instance_prompt = args .instance_prompt ,
18801880 train_text_encoder_ti = args .train_text_encoder_ti ,
@@ -1884,7 +1884,6 @@ def load_model_hook(models, input_dir):
18841884 class_num = args .num_class_images ,
18851885 size = args .resolution ,
18861886 repeats = args .repeats ,
1887- center_crop = args .center_crop ,
18881887 )
18891888
18901889 train_dataloader = torch .utils .data .DataLoader (
0 commit comments