Skip to content

Commit 1b9f29e

Browse files
sayakpaulhlky
andauthored
Update examples/dreambooth/train_dreambooth_lora_sdxl.py
Co-authored-by: hlky <[email protected]>
1 parent 044eb83 commit 1b9f29e

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

examples/dreambooth/train_dreambooth_lora_sdxl.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -798,10 +798,10 @@ def __init__(
798798
self.crop_top_lefts = []
799799
self.pixel_values = []
800800

801-
if args.image_interpolation_mode == "bilinear":
802-
train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
803-
else:
804-
train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.LANCZOS)
801+
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
802+
if interpolation is None:
803+
raise ValueError(f"Unsupported interpolation mode.")
804+
train_resize = transforms.Resize(size, interpolation=interpolation)
805805

806806
train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
807807
train_flip = transforms.RandomHorizontalFlip(p=1.0)

0 commit comments

Comments
 (0)