Skip to content

Commit 9bf5c56

Browse files
authored
Revert "[Examples] Uniform notations in train_flux_lora (#10011)"
This reverts commit 173e1b1.
1 parent 173e1b1 commit 9bf5c56

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -971,7 +971,6 @@ class DreamBoothDataset(Dataset):
971971

972972
def __init__(
973973
self,
974-
args,
975974
instance_data_root,
976975
instance_prompt,
977976
class_prompt,
@@ -981,8 +980,10 @@ def __init__(
981980
class_num=None,
982981
size=1024,
983982
repeats=1,
983+
center_crop=False,
984984
):
985985
self.size = size
986+
self.center_crop = center_crop
986987

987988
self.instance_prompt = instance_prompt
988989
self.custom_instance_prompts = None
@@ -1074,11 +1075,11 @@ def __init__(
10741075
# flip
10751076
image = train_flip(image)
10761077
if args.center_crop:
1077-
y1 = max(0, int(round((image.height - self.size) / 2.0)))
1078-
x1 = max(0, int(round((image.width - self.size) / 2.0)))
1078+
y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
1079+
x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
10791080
image = train_crop(image)
10801081
else:
1081-
y1, x1, h, w = train_crop.get_params(image, (self.size, self.size))
1082+
y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
10821083
image = crop(image, y1, x1, h, w)
10831084
image = train_transforms(image)
10841085
self.pixel_values.append(image)
@@ -1826,7 +1827,6 @@ def load_model_hook(models, input_dir):
18261827

18271828
# Dataset and DataLoaders creation:
18281829
train_dataset = DreamBoothDataset(
1829-
args=args,
18301830
instance_data_root=args.instance_data_dir,
18311831
instance_prompt=args.instance_prompt,
18321832
train_text_encoder_ti=args.train_text_encoder_ti,
@@ -1836,6 +1836,7 @@ def load_model_hook(models, input_dir):
18361836
class_num=args.num_class_images,
18371837
size=args.resolution,
18381838
repeats=args.repeats,
1839+
center_crop=args.center_crop,
18391840
)
18401841

18411842
train_dataloader = torch.utils.data.DataLoader(

0 commit comments

Comments
 (0)