Skip to content

Commit 274b84e

Browse files
authored
Merge branch 'main' into groupwise-offloading
2 parents 6be43b8 + 8ae8008 commit 274b84e

File tree

8 files changed

+205
-17
lines changed

8 files changed

+205
-17
lines changed

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,7 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
519519
remaining_keys = list(sds_sd.keys())
520520
te_state_dict = {}
521521
if remaining_keys:
522-
if not all(k.startswith("lora_te1") for k in remaining_keys):
522+
if not all(k.startswith("lora_te") for k in remaining_keys):
523523
raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}")
524524
for key in remaining_keys:
525525
if not key.endswith("lora_down.weight"):
@@ -558,6 +558,88 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
558558
new_state_dict = {**ait_sd, **te_state_dict}
559559
return new_state_dict
560560

561+
def _convert_mixture_state_dict_to_diffusers(state_dict):
562+
new_state_dict = {}
563+
564+
def _convert(original_key, diffusers_key, state_dict, new_state_dict):
565+
down_key = f"{original_key}.lora_down.weight"
566+
down_weight = state_dict.pop(down_key)
567+
lora_rank = down_weight.shape[0]
568+
569+
up_weight_key = f"{original_key}.lora_up.weight"
570+
up_weight = state_dict.pop(up_weight_key)
571+
572+
alpha_key = f"{original_key}.alpha"
573+
alpha = state_dict.pop(alpha_key)
574+
575+
# scale weight by alpha and dim
576+
scale = alpha / lora_rank
577+
# calculate scale_down and scale_up
578+
scale_down = scale
579+
scale_up = 1.0
580+
while scale_down * 2 < scale_up:
581+
scale_down *= 2
582+
scale_up /= 2
583+
down_weight = down_weight * scale_down
584+
up_weight = up_weight * scale_up
585+
586+
diffusers_down_key = f"{diffusers_key}.lora_A.weight"
587+
new_state_dict[diffusers_down_key] = down_weight
588+
new_state_dict[diffusers_down_key.replace(".lora_A.", ".lora_B.")] = up_weight
589+
590+
all_unique_keys = {
591+
k.replace(".lora_down.weight", "").replace(".lora_up.weight", "").replace(".alpha", "") for k in state_dict
592+
}
593+
all_unique_keys = sorted(all_unique_keys)
594+
assert all("lora_transformer_" in k for k in all_unique_keys), f"{all_unique_keys=}"
595+
596+
for k in all_unique_keys:
597+
if k.startswith("lora_transformer_single_transformer_blocks_"):
598+
i = int(k.split("lora_transformer_single_transformer_blocks_")[-1].split("_")[0])
599+
diffusers_key = f"single_transformer_blocks.{i}"
600+
elif k.startswith("lora_transformer_transformer_blocks_"):
601+
i = int(k.split("lora_transformer_transformer_blocks_")[-1].split("_")[0])
602+
diffusers_key = f"transformer_blocks.{i}"
603+
else:
604+
raise NotImplementedError
605+
606+
if "attn_" in k:
607+
if "_to_out_0" in k:
608+
diffusers_key += ".attn.to_out.0"
609+
elif "_to_add_out" in k:
610+
diffusers_key += ".attn.to_add_out"
611+
elif any(qkv in k for qkv in ["to_q", "to_k", "to_v"]):
612+
remaining = k.split("attn_")[-1]
613+
diffusers_key += f".attn.{remaining}"
614+
elif any(add_qkv in k for add_qkv in ["add_q_proj", "add_k_proj", "add_v_proj"]):
615+
remaining = k.split("attn_")[-1]
616+
diffusers_key += f".attn.{remaining}"
617+
618+
if diffusers_key == f"transformer_blocks.{i}":
619+
print(k, diffusers_key)
620+
_convert(k, diffusers_key, state_dict, new_state_dict)
621+
622+
if len(state_dict) > 0:
623+
raise ValueError(
624+
f"Expected an empty state dict at this point but its has these keys which couldn't be parsed: {list(state_dict.keys())}."
625+
)
626+
627+
new_state_dict = {f"transformer.{k}": v for k, v in new_state_dict.items()}
628+
return new_state_dict
629+
630+
# This is weird.
631+
# https://huggingface.co/sayakpaul/different-lora-from-civitai/tree/main?show_file_info=sharp_detailed_foot.safetensors
632+
# has both `peft` and non-peft state dict.
633+
has_peft_state_dict = any(k.startswith("transformer.") for k in state_dict)
634+
if has_peft_state_dict:
635+
state_dict = {k: v for k, v in state_dict.items() if k.startswith("transformer.")}
636+
return state_dict
637+
# Another weird one.
638+
has_mixture = any(
639+
k.startswith("lora_transformer_") and ("lora_down" in k or "lora_up" in k or "alpha" in k) for k in state_dict
640+
)
641+
if has_mixture:
642+
return _convert_mixture_state_dict_to_diffusers(state_dict)
561643
return _convert_sd_scripts_to_ai_toolkit(state_dict)
562644

563645

src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,11 @@
3636
def prepare_causal_attention_mask(
3737
num_frames: int, height_width: int, dtype: torch.dtype, device: torch.device, batch_size: int = None
3838
) -> torch.Tensor:
39-
seq_len = num_frames * height_width
40-
mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
41-
for i in range(seq_len):
42-
i_frame = i // height_width
43-
mask[i, : (i_frame + 1) * height_width] = 0
39+
indices = torch.arange(1, num_frames + 1, dtype=torch.int32, device=device)
40+
indices_blocks = indices.repeat_interleave(height_width)
41+
x, y = torch.meshgrid(indices_blocks, indices_blocks, indexing="xy")
42+
mask = torch.where(x <= y, 0, -float("inf")).to(dtype=dtype)
43+
4444
if batch_size is not None:
4545
mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
4646
return mask

src/diffusers/models/modeling_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
from huggingface_hub import DDUFEntry, create_repo, split_torch_state_dict_into_shards
3232
from huggingface_hub.utils import validate_hf_hub_args
3333
from torch import Tensor, nn
34+
from typing_extensions import Self
3435

3536
from .. import __version__
3637
from ..hooks import apply_group_offloading, apply_layerwise_casting
@@ -665,7 +666,7 @@ def dequantize(self):
665666

666667
@classmethod
667668
@validate_hf_hub_args
668-
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
669+
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs) -> Self:
669670
r"""
670671
Instantiate a pretrained PyTorch model from a pretrained model configuration.
671672

src/diffusers/schedulers/scheduling_edm_euler.py

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import math
1616
from dataclasses import dataclass
17-
from typing import Optional, Tuple, Union
17+
from typing import List, Optional, Tuple, Union
1818

1919
import torch
2020

@@ -77,6 +77,9 @@ class EDMEulerScheduler(SchedulerMixin, ConfigMixin):
7777
Video](https://imagen.research.google/video/paper.pdf) paper).
7878
rho (`float`, *optional*, defaults to 7.0):
7979
The rho parameter used for calculating the Karras sigma schedule, which is set to 7.0 in the EDM paper [1].
80+
final_sigmas_type (`str`, defaults to `"zero"`):
81+
The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final
82+
sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0.
8083
"""
8184

8285
_compatibles = []
@@ -92,22 +95,32 @@ def __init__(
9295
num_train_timesteps: int = 1000,
9396
prediction_type: str = "epsilon",
9497
rho: float = 7.0,
98+
final_sigmas_type: str = "zero", # can be "zero" or "sigma_min"
9599
):
96100
if sigma_schedule not in ["karras", "exponential"]:
97101
raise ValueError(f"Wrong value for provided for `{sigma_schedule=}`.`")
98102

99103
# setable values
100104
self.num_inference_steps = None
101105

102-
ramp = torch.linspace(0, 1, num_train_timesteps)
106+
sigmas = torch.arange(num_train_timesteps + 1) / num_train_timesteps
103107
if sigma_schedule == "karras":
104-
sigmas = self._compute_karras_sigmas(ramp)
108+
sigmas = self._compute_karras_sigmas(sigmas)
105109
elif sigma_schedule == "exponential":
106-
sigmas = self._compute_exponential_sigmas(ramp)
110+
sigmas = self._compute_exponential_sigmas(sigmas)
107111

108112
self.timesteps = self.precondition_noise(sigmas)
109113

110-
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
114+
if self.config.final_sigmas_type == "sigma_min":
115+
sigma_last = sigmas[-1]
116+
elif self.config.final_sigmas_type == "zero":
117+
sigma_last = 0
118+
else:
119+
raise ValueError(
120+
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
121+
)
122+
123+
self.sigmas = torch.cat([sigmas, torch.full((1,), fill_value=sigma_last, device=sigmas.device)])
111124

112125
self.is_scale_input_called = False
113126

@@ -197,7 +210,12 @@ def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.T
197210
self.is_scale_input_called = True
198211
return sample
199212

200-
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
213+
def set_timesteps(
214+
self,
215+
num_inference_steps: int = None,
216+
device: Union[str, torch.device] = None,
217+
sigmas: Optional[Union[torch.Tensor, List[float]]] = None,
218+
):
201219
"""
202220
Sets the discrete timesteps used for the diffusion chain (to be run before inference).
203221
@@ -206,19 +224,36 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
206224
The number of diffusion steps used when generating samples with a pre-trained model.
207225
device (`str` or `torch.device`, *optional*):
208226
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
227+
sigmas (`Union[torch.Tensor, List[float]]`, *optional*):
228+
Custom sigmas to use for the denoising process. If not defined, the default behavior when
229+
`num_inference_steps` is passed will be used.
209230
"""
210231
self.num_inference_steps = num_inference_steps
211232

212-
ramp = torch.linspace(0, 1, self.num_inference_steps)
233+
if sigmas is None:
234+
sigmas = torch.linspace(0, 1, self.num_inference_steps)
235+
elif isinstance(sigmas, float):
236+
sigmas = torch.tensor(sigmas, dtype=torch.float32)
237+
else:
238+
sigmas = sigmas
213239
if self.config.sigma_schedule == "karras":
214-
sigmas = self._compute_karras_sigmas(ramp)
240+
sigmas = self._compute_karras_sigmas(sigmas)
215241
elif self.config.sigma_schedule == "exponential":
216-
sigmas = self._compute_exponential_sigmas(ramp)
242+
sigmas = self._compute_exponential_sigmas(sigmas)
217243

218244
sigmas = sigmas.to(dtype=torch.float32, device=device)
219245
self.timesteps = self.precondition_noise(sigmas)
220246

221-
self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
247+
if self.config.final_sigmas_type == "sigma_min":
248+
sigma_last = sigmas[-1]
249+
elif self.config.final_sigmas_type == "zero":
250+
sigma_last = 0
251+
else:
252+
raise ValueError(
253+
f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}"
254+
)
255+
256+
self.sigmas = torch.cat([sigmas, torch.full((1,), fill_value=sigma_last, device=sigmas.device)])
222257
self._step_index = None
223258
self._begin_index = None
224259
self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication

tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import torch
1919

2020
from diffusers import AutoencoderKLHunyuanVideo
21+
from diffusers.models.autoencoders.autoencoder_kl_hunyuan_video import prepare_causal_attention_mask
2122
from diffusers.utils.testing_utils import (
2223
enable_full_determinism,
2324
floats_tensor,
@@ -182,3 +183,28 @@ def test_forward_with_norm_groups(self):
182183
@unittest.skip("Unsupported test.")
183184
def test_outputs_equivalence(self):
184185
pass
186+
187+
def test_prepare_causal_attention_mask(self):
188+
def prepare_causal_attention_mask_orig(
189+
num_frames: int, height_width: int, dtype: torch.dtype, device: torch.device, batch_size: int = None
190+
) -> torch.Tensor:
191+
seq_len = num_frames * height_width
192+
mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
193+
for i in range(seq_len):
194+
i_frame = i // height_width
195+
mask[i, : (i_frame + 1) * height_width] = 0
196+
if batch_size is not None:
197+
mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
198+
return mask
199+
200+
# test with some odd shapes
201+
original_mask = prepare_causal_attention_mask_orig(
202+
num_frames=31, height_width=111, dtype=torch.float32, device=torch_device
203+
)
204+
new_mask = prepare_causal_attention_mask(
205+
num_frames=31, height_width=111, dtype=torch.float32, device=torch_device
206+
)
207+
self.assertTrue(
208+
torch.allclose(original_mask, new_mask),
209+
"Causal attention mask should be the same",
210+
)

tests/models/autoencoders/test_models_autoencoder_oobleck.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,12 @@ def test_forward_with_norm_groups(self):
114114
def test_set_attn_processor_for_determinism(self):
115115
return
116116

117+
@unittest.skip(
118+
"Test not supported because of 'weight_norm_fwd_first_dim_kernel' not implemented for 'Float8_e4m3fn'"
119+
)
120+
def test_layerwise_casting_training(self):
121+
return super().test_layerwise_casting_training()
122+
117123
@unittest.skip(
118124
"The convolution layers of AutoencoderOobleck are wrapped with torch.nn.utils.weight_norm. This causes the hook's pre_forward to not "
119125
"cast the module weights to compute_dtype (as required by forward pass). As a result, forward pass errors out. To fix:\n"

tests/models/test_modeling_common.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1338,6 +1338,36 @@ def test_variant_sharded_ckpt_right_format(self):
13381338
# Example: diffusion_pytorch_model.fp16-00001-of-00002.safetensors
13391339
assert all(f.split(".")[1].split("-")[0] == variant for f in shard_files)
13401340

1341+
def test_layerwise_casting_training(self):
1342+
def test_fn(storage_dtype, compute_dtype):
1343+
if torch.device(torch_device).type == "cpu" and compute_dtype == torch.bfloat16:
1344+
return
1345+
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
1346+
1347+
model = self.model_class(**init_dict)
1348+
model = model.to(torch_device, dtype=compute_dtype)
1349+
model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
1350+
model.train()
1351+
1352+
inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
1353+
with torch.amp.autocast(device_type=torch.device(torch_device).type):
1354+
output = model(**inputs_dict)
1355+
1356+
if isinstance(output, dict):
1357+
output = output.to_tuple()[0]
1358+
1359+
input_tensor = inputs_dict[self.main_input_name]
1360+
noise = torch.randn((input_tensor.shape[0],) + self.output_shape).to(torch_device)
1361+
noise = cast_maybe_tensor_dtype(noise, torch.float32, compute_dtype)
1362+
loss = torch.nn.functional.mse_loss(output, noise)
1363+
1364+
loss.backward()
1365+
1366+
test_fn(torch.float16, torch.float32)
1367+
test_fn(torch.float8_e4m3fn, torch.float32)
1368+
test_fn(torch.float8_e5m2, torch.float32)
1369+
test_fn(torch.float8_e4m3fn, torch.bfloat16)
1370+
13411371
def test_layerwise_casting_inference(self):
13421372
from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS
13431373

tests/models/unets/test_models_unet_1d.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,10 @@ def test_ema_training(self):
6060
def test_training(self):
6161
pass
6262

63+
@unittest.skip("Test not supported.")
64+
def test_layerwise_casting_training(self):
65+
pass
66+
6367
def test_determinism(self):
6468
super().test_determinism()
6569

@@ -239,6 +243,10 @@ def test_ema_training(self):
239243
def test_training(self):
240244
pass
241245

246+
@unittest.skip("Test not supported.")
247+
def test_layerwise_casting_training(self):
248+
pass
249+
242250
def prepare_init_args_and_inputs_for_common(self):
243251
init_dict = {
244252
"in_channels": 14,

0 commit comments

Comments
 (0)