
Commit 6f7011a

Merge branch 'main' into enable-hotswap-testing-ci
2 parents c062b08 + 4a9ab65

7 files changed: +158 / -61 lines

examples/dreambooth/train_dreambooth_lora_hidream.py

Lines changed: 0 additions & 1 deletion
@@ -236,7 +236,6 @@ def log_validation(
                 }
             )
 
-    pipeline.to("cpu")
     del pipeline
     free_memory()

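For context, the cleanup that remains after this deletion can be sketched standalone. A minimal sketch, assuming diffusers' free_memory helper from diffusers.training_utils and a tiny test checkpoint chosen purely for illustration: dropping the pipeline reference and clearing caches is enough, with no pipeline.to("cpu") round trip beforehand.

import torch
from diffusers import DiffusionPipeline
from diffusers.training_utils import free_memory

def run_tiny_validation():
    # Build a throwaway pipeline, generate, then release accelerator memory.
    pipe = DiffusionPipeline.from_pretrained(
        "hf-internal-testing/tiny-stable-diffusion-pipe", torch_dtype=torch.float32
    )
    images = pipe("a photo of a dog", num_inference_steps=2).images
    del pipe       # drop the last reference; no .to("cpu") hop needed first
    free_memory()  # gc.collect() plus emptying the accelerator cache
    return images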
examples/text_to_image/train_text_to_image.py

Lines changed: 18 additions & 2 deletions
@@ -499,6 +499,15 @@ def parse_args():
             " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
         ),
     )
+    parser.add_argument(
+        "--image_interpolation_mode",
+        type=str,
+        default="lanczos",
+        choices=[
+            f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
+        ],
+        help="The image interpolation method to use for resizing images.",
+    )
 
     args = parser.parse_args()
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))

@@ -787,10 +796,17 @@ def tokenize_captions(examples, is_train=True):
         )
         return inputs.input_ids
 
-    # Preprocessing the datasets.
+    # Get the specified interpolation method from the args
+    interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
+
+    # Raise an error if the interpolation method is invalid
+    if interpolation is None:
+        raise ValueError(f"Unsupported interpolation mode {args.image_interpolation_mode}.")
+
+    # Data preprocessing transformations
     train_transforms = transforms.Compose(
         [
-            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+            transforms.Resize(args.resolution, interpolation=interpolation),  # Use dynamic interpolation method
             transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
             transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
             transforms.ToTensor(),

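The resolution logic added above can be exercised on its own. A minimal sketch, assuming torchvision is installed and using "bicubic" only as an example value (the script's default remains "lanczos"):

from torchvision import transforms

requested = "bicubic"  # example value a user might pass via --image_interpolation_mode
interpolation = getattr(transforms.InterpolationMode, requested.upper(), None)
if interpolation is None:
    raise ValueError(f"Unsupported interpolation mode {requested}.")

# The resize transform now uses the requested mode instead of a hard-coded BILINEAR.
resize = transforms.Resize(256, interpolation=interpolation)
print(resize)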
examples/text_to_image/train_text_to_image_lora.py

Lines changed: 18 additions & 2 deletions
@@ -418,6 +418,15 @@ def parse_args():
         default=4,
         help=("The dimension of the LoRA update matrices."),
     )
+    parser.add_argument(
+        "--image_interpolation_mode",
+        type=str,
+        default="lanczos",
+        choices=[
+            f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
+        ],
+        help="The image interpolation method to use for resizing images.",
+    )
 
     args = parser.parse_args()
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))

@@ -649,10 +658,17 @@ def tokenize_captions(examples, is_train=True):
        )
        return inputs.input_ids
 
-    # Preprocessing the datasets.
+    # Get the specified interpolation method from the args
+    interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
+
+    # Raise an error if the interpolation method is invalid
+    if interpolation is None:
+        raise ValueError(f"Unsupported interpolation mode {args.image_interpolation_mode}.")
+
+    # Data preprocessing transformations
     train_transforms = transforms.Compose(
         [
-            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+            transforms.Resize(args.resolution, interpolation=interpolation),  # Use dynamic interpolation method
             transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
             transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
             transforms.ToTensor(),

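The choices accepted by the new flag are derived from torchvision's InterpolationMode enum. A small sketch of what that comprehension evaluates to; the exact members depend on the installed torchvision version:

from torchvision import transforms

choices = [
    f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
]
print(choices)  # e.g. ['bicubic', 'bilinear', 'box', 'hamming', 'lanczos', 'nearest', 'nearest_exact']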
src/diffusers/loaders/lora_pipeline.py

Lines changed: 4 additions & 3 deletions
@@ -91,18 +91,19 @@ def _maybe_dequantize_weight_for_expanded_lora(model, module):
         )
 
     weight_on_cpu = False
-    if not module.weight.is_cuda:
+    if module.weight.device.type == "cpu":
         weight_on_cpu = True
 
+    device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
     if is_bnb_4bit_quantized:
         module_weight = dequantize_bnb_weight(
-            module.weight.cuda() if weight_on_cpu else module.weight,
+            module.weight.to(device) if weight_on_cpu else module.weight,
             state=module.weight.quant_state,
             dtype=model.dtype,
         ).data
     elif is_gguf_quantized:
         module_weight = dequantize_gguf_tensor(
-            module.weight.cuda() if weight_on_cpu else module.weight,
+            module.weight.to(device) if weight_on_cpu else module.weight,
         )
         module_weight = module_weight.to(model.dtype)
     else:

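The device-selection fallback used above can be shown in isolation. A minimal sketch, assuming a PyTorch build that ships the torch.accelerator API (roughly 2.6 and later), with "cuda" kept as the legacy fallback exactly as in the diff:

import torch

# Prefer the generic accelerator API when present (CUDA, XPU, MPS, ...); otherwise
# fall back to "cuda" as before. Note that current_accelerator() can return None on
# a machine with no accelerator, so this one-liner assumes one is available.
device = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
print(device)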
src/diffusers/pipelines/onnx_utils.py

Lines changed: 15 additions & 2 deletions
@@ -75,6 +75,11 @@ def load_model(path: Union[str, Path], provider=None, sess_options=None, provide
             logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
             provider = "CPUExecutionProvider"
 
+        if provider_options is None:
+            provider_options = []
+        elif not isinstance(provider_options, list):
+            provider_options = [provider_options]
+
         return ort.InferenceSession(
             path, providers=[provider], sess_options=sess_options, provider_options=provider_options
         )

@@ -174,7 +179,10 @@ def _from_pretrained(
         # load model from local directory
         if os.path.isdir(model_id):
             model = OnnxRuntimeModel.load_model(
-                Path(model_id, model_file_name).as_posix(), provider=provider, sess_options=sess_options
+                Path(model_id, model_file_name).as_posix(),
+                provider=provider,
+                sess_options=sess_options,
+                provider_options=kwargs.pop("provider_options"),
             )
             kwargs["model_save_dir"] = Path(model_id)
         # load model from hub

@@ -190,7 +198,12 @@ def _from_pretrained(
             )
             kwargs["model_save_dir"] = Path(model_cache_path).parent
             kwargs["latest_model_name"] = Path(model_cache_path).name
-            model = OnnxRuntimeModel.load_model(model_cache_path, provider=provider, sess_options=sess_options)
+            model = OnnxRuntimeModel.load_model(
+                model_cache_path,
+                provider=provider,
+                sess_options=sess_options,
+                provider_options=kwargs.pop("provider_options"),
+            )
         return cls(model=model, **kwargs)
 
     @classmethod

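With provider_options now threaded through to onnxruntime, per-provider settings can be passed end to end. A hedged sketch; the model path and the CUDA device id are placeholders, not values from this commit:

import onnxruntime as ort

# onnxruntime expects providers and provider_options as parallel lists, which is why
# load_model now wraps a single options value into a one-element list.
session = ort.InferenceSession(
    "model.onnx",                         # placeholder path
    providers=["CUDAExecutionProvider"],
    provider_options=[{"device_id": 0}],  # one options dict per provider
)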
src/diffusers/quantizers/gguf/gguf_quantizer.py

Lines changed: 7 additions & 2 deletions
@@ -150,9 +150,14 @@ def _dequantize(self, model):
         is_model_on_cpu = model.device.type == "cpu"
         if is_model_on_cpu:
             logger.info(
-                "Model was found to be on CPU (could happen as a result of `enable_model_cpu_offload()`). So, moving it to GPU. After dequantization, will move the model back to CPU again to preserve the previous device."
+                "Model was found to be on CPU (could happen as a result of `enable_model_cpu_offload()`). So, moving it to accelerator. After dequantization, will move the model back to CPU again to preserve the previous device."
             )
-            model.to(torch.cuda.current_device())
+            device = (
+                torch.accelerator.current_accelerator()
+                if hasattr(torch, "accelerator")
+                else torch.cuda.current_device()
+            )
+            model.to(device)
 
         model = _dequantize_gguf_and_restore_linear(model, self.modules_to_not_convert)
         if is_model_on_cpu:

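The same accelerator-aware device pick appears here wrapped in a CPU round trip. A hypothetical helper sketching that flow; dequantize_fn stands in for diffusers' internal _dequantize_gguf_and_restore_linear and is not part of this commit:

import torch
import torch.nn as nn

def dequantize_on_accelerator(model: nn.Module, dequantize_fn):
    # Remember whether the model was offloaded to CPU, hop to the accelerator for the
    # heavy dequantization work, then restore the original placement afterwards.
    was_on_cpu = next(model.parameters()).device.type == "cpu"
    if was_on_cpu:
        device = (
            torch.accelerator.current_accelerator()
            if hasattr(torch, "accelerator")
            else torch.cuda.current_device()
        )
        model.to(device)
    model = dequantize_fn(model)
    if was_on_cpu:
        model.to("cpu")
    return model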