Skip to content

Commit cb53d9c

Browse files
committed
Merge remote-tracking branch 'origin/hidream-followup' into hidream-followup
2 parents 1efdc2a + 0e6fa1b commit cb53d9c

File tree

1 file changed

+13
-1
lines changed

1 file changed

+13
-1
lines changed

examples/dreambooth/train_dreambooth_flux.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,15 @@ def parse_args(input_args=None):
618618
),
619619
)
620620
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
621+
parser.add_argument(
622+
"--image_interpolation_mode",
623+
type=str,
624+
default="lanczos",
625+
choices=[
626+
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
627+
],
628+
help="The image interpolation method to use for resizing images.",
629+
)
621630

622631
if input_args is not None:
623632
args = parser.parse_args(input_args)
@@ -737,7 +746,10 @@ def __init__(
737746
self.instance_images.extend(itertools.repeat(img, repeats))
738747

739748
self.pixel_values = []
740-
train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
749+
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
750+
if interpolation is None:
751+
raise ValueError(f"Unsupported interpolation mode {interpolation=}.")
752+
train_resize = transforms.Resize(size, interpolation=interpolation)
741753
train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
742754
train_flip = transforms.RandomHorizontalFlip(p=1.0)
743755
train_transforms = transforms.Compose(

0 commit comments

Comments
 (0)