Skip to content

Commit 6c0b7b6

Browse files
authored
[Examples] uniform naming notations
since the in parameter `size` represents `args.resolution`, I thus replace the `args.resolution` inside DreamBoothData with `size`. And revise some notations such as `center_crop`.
1 parent c4b5d2f commit 6c0b7b6

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -934,6 +934,7 @@ class DreamBoothDataset(Dataset):
934934

935935
def __init__(
936936
self,
937+
args,
937938
instance_data_root,
938939
instance_prompt,
939940
class_prompt,
@@ -943,10 +944,8 @@ def __init__(
943944
class_num=None,
944945
size=1024,
945946
repeats=1,
946-
center_crop=False,
947947
):
948948
self.size = size
949-
self.center_crop = center_crop
950949

951950
self.instance_prompt = instance_prompt
952951
self.custom_instance_prompts = None
@@ -1035,11 +1034,11 @@ def __init__(
10351034
# flip
10361035
image = train_flip(image)
10371036
if args.center_crop:
1038-
y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
1039-
x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
1037+
y1 = max(0, int(round((image.height - self.size) / 2.0)))
1038+
x1 = max(0, int(round((image.width - self.size) / 2.0)))
10401039
image = train_crop(image)
10411040
else:
1042-
y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
1041+
y1, x1, h, w = train_crop.get_params(image, (self.size, self.size))
10431042
image = crop(image, y1, x1, h, w)
10441043
image = train_transforms(image)
10451044
self.pixel_values.append(image)
@@ -1875,6 +1874,7 @@ def load_model_hook(models, input_dir):
18751874

18761875
# Dataset and DataLoaders creation:
18771876
train_dataset = DreamBoothDataset(
1877+
args=args,
18781878
instance_data_root=args.instance_data_dir,
18791879
instance_prompt=args.instance_prompt,
18801880
train_text_encoder_ti=args.train_text_encoder_ti,
@@ -1884,7 +1884,6 @@ def load_model_hook(models, input_dir):
18841884
class_num=args.num_class_images,
18851885
size=args.resolution,
18861886
repeats=args.repeats,
1887-
center_crop=args.center_crop,
18881887
)
18891888

18901889
train_dataloader = torch.utils.data.DataLoader(

0 commit comments

Comments
 (0)