Commit 263422a

Merge branch 'main' into lora-bnb-tests-peft

2 parents 9e9f057 + be54a95

7 files changed: +252 −65 lines

examples/controlnet/train_controlnet_sd3.py
Lines changed: 2 additions & 2 deletions

```diff
@@ -1283,8 +1283,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise

                 # Get the text embedding for conditioning
-                prompt_embeds = batch["prompt_embeds"]
-                pooled_prompt_embeds = batch["pooled_prompt_embeds"]
+                prompt_embeds = batch["prompt_embeds"].to(dtype=weight_dtype)
+                pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(dtype=weight_dtype)

                 # controlnet(s) inference
                 controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)
```
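Note: this matters for mixed-precision training, where precomputed prompt embeddings come out of the dataloader in float32 while the model weights have been cast down. A minimal sketch of the failure mode, assuming a float16 `weight_dtype` (the tensors and layer below are stand-ins, not the training script's):

```python
import torch

# Stand-ins for the training setup: weights cast to half precision,
# dataloader tensors still float32.
weight_dtype = torch.float16
proj = torch.nn.Linear(64, 64).to(dtype=weight_dtype)
prompt_embeds = torch.randn(2, 77, 64)  # float32, as loaded from the batch

# proj(prompt_embeds) would raise a dtype-mismatch error here;
# the explicit cast mirrors what the diff adds.
prompt_embeds = prompt_embeds.to(dtype=weight_dtype)
out = proj(prompt_embeds)
print(out.dtype)  # torch.float16
```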

examples/research_projects/pytorch_xla/inference/flux/README.md

Lines changed: 109 additions & 44 deletions
Large diffs are not rendered by default.

examples/research_projects/pytorch_xla/inference/flux/flux_inference.py
Lines changed: 21 additions & 7 deletions

```diff
@@ -9,6 +9,7 @@
 import torch_xla.debug.profiler as xp
 import torch_xla.distributed.xla_multiprocessing as xmp
 import torch_xla.runtime as xr
+from torch_xla.experimental.custom_kernel import FlashAttention

 from diffusers import FluxPipeline

@@ -36,6 +37,19 @@ def _main(index, args, text_pipe, ckpt_id):
         ckpt_id, text_encoder=None, tokenizer=None, text_encoder_2=None, tokenizer_2=None, torch_dtype=torch.bfloat16
     ).to(device0)
     flux_pipe.transformer.enable_xla_flash_attention(partition_spec=("data", None, None, None), is_flux=True)
+    FlashAttention.DEFAULT_BLOCK_SIZES = {
+        "block_q": 1536,
+        "block_k_major": 1536,
+        "block_k": 1536,
+        "block_b": 1536,
+        "block_q_major_dkv": 1536,
+        "block_k_major_dkv": 1536,
+        "block_q_dkv": 1536,
+        "block_k_dkv": 1536,
+        "block_q_dq": 1536,
+        "block_k_dq": 1536,
+        "block_k_major_dq": 1536,
+    }

     prompt = "photograph of an electronics chip in the shape of a race car with trillium written on its side"
     width = args.width
@@ -69,14 +83,14 @@ def _main(index, args, text_pipe, ckpt_id):
     xm.set_rng_state(seed=unique_seed, device=device0)
     times = []
     logger.info("starting inference run...")
+    with torch.no_grad():
+        prompt_embeds, pooled_prompt_embeds, text_ids = text_pipe.encode_prompt(
+            prompt=prompt, prompt_2=None, max_sequence_length=512
+        )
+        prompt_embeds = prompt_embeds.to(device0)
+        pooled_prompt_embeds = pooled_prompt_embeds.to(device0)
     for _ in range(args.itters):
         ts = perf_counter()
-        with torch.no_grad():
-            prompt_embeds, pooled_prompt_embeds, text_ids = text_pipe.encode_prompt(
-                prompt=prompt, prompt_2=None, max_sequence_length=512
-            )
-            prompt_embeds = prompt_embeds.to(device0)
-            pooled_prompt_embeds = pooled_prompt_embeds.to(device0)

         if args.profile:
             xp.trace_detached(f"localhost:{profiler_port}", str(profile_path), duration_ms=profile_duration)
@@ -92,7 +106,7 @@ def _main(index, args, text_pipe, ckpt_id):
         if index == 0:
             logger.info(f"inference time: {inference_time}")
         times.append(inference_time)
-    logger.info(f"avg. inference over {args.itters} iterations took {sum(times)/len(times)} sec.")
+    logger.info(f"avg. inference over {args.itters} iterations took {sum(times) / len(times)} sec.")
     image.save(f"/tmp/inference_out-{index}.png")
     if index == 0:
         metrics_report = met.metrics_report()
```
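Note: two independent changes here. The FlashAttention block sizes are raised to 1536, presumably tiling tuned for this TPU workload, and prompt encoding is hoisted out of the timed loop so each iteration measures only the denoising run. The hoisting pattern in isolation, as a sketch (the `encode_prompt`/`denoise` callables are placeholders):

```python
from time import perf_counter

def benchmark(encode_prompt, denoise, iters):
    # Loop-invariant work runs once, outside the timed region.
    cond = encode_prompt()
    times = []
    for _ in range(iters):
        ts = perf_counter()
        denoise(cond)  # only the per-iteration work is timed
        times.append(perf_counter() - ts)
    return sum(times) / len(times)
```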

src/diffusers/loaders/lora_conversion_utils.py
Lines changed: 10 additions & 7 deletions

```diff
@@ -1355,6 +1355,7 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
     original_state_dict = {k[len("diffusion_model.") :]: v for k, v in state_dict.items()}

     num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in original_state_dict})
+    is_i2v_lora = any("k_img" in k for k in original_state_dict) and any("v_img" in k for k in original_state_dict)

     for i in range(num_blocks):
         # Self-attention
@@ -1374,13 +1375,15 @@ def _convert_non_diffusers_wan_lora_to_diffusers(state_dict):
             converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
                 f"blocks.{i}.cross_attn.{o}.lora_B.weight"
             )
-        for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
-            converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
-                f"blocks.{i}.cross_attn.{o}.lora_A.weight"
-            )
-            converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
-                f"blocks.{i}.cross_attn.{o}.lora_B.weight"
-            )
+
+        if is_i2v_lora:
+            for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
+                converted_state_dict[f"blocks.{i}.attn2.{c}.lora_A.weight"] = original_state_dict.pop(
+                    f"blocks.{i}.cross_attn.{o}.lora_A.weight"
+                )
+                converted_state_dict[f"blocks.{i}.attn2.{c}.lora_B.weight"] = original_state_dict.pop(
+                    f"blocks.{i}.cross_attn.{o}.lora_B.weight"
+                )

         # FFN
         for o, c in zip(["ffn.0", "ffn.2"], ["net.0.proj", "net.2"]):
```
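Note: the new `is_i2v_lora` flag keeps the converter from `pop()`-ing `k_img`/`v_img` keys that only exist in image-to-video Wan checkpoints; on a text-to-video LoRA those pops would raise `KeyError`. A small self-contained check of the detection logic, with toy keys:

```python
def is_i2v(state_dict):
    # Image-to-video Wan LoRAs carry extra image cross-attention projections.
    return any("k_img" in k for k in state_dict) and any("v_img" in k for k in state_dict)

t2v_keys = {"blocks.0.cross_attn.k.lora_A.weight": None}
i2v_keys = {
    **t2v_keys,
    "blocks.0.cross_attn.k_img.lora_A.weight": None,
    "blocks.0.cross_attn.v_img.lora_A.weight": None,
}
print(is_i2v(t2v_keys), is_i2v(i2v_keys))  # False True
```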

src/diffusers/models/attention_processor.py
Lines changed: 3 additions & 1 deletion

```diff
@@ -2339,7 +2339,9 @@ def __call__(
             query = apply_rotary_emb(query, image_rotary_emb)
             key = apply_rotary_emb(key, image_rotary_emb)

-        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )

         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
```
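Note: before this change the processor silently dropped any `attention_mask` it was handed; `attn_mask=None` reproduces the old behavior. A minimal standalone demonstration of the now-forwarded argument (shapes are arbitrary; a boolean mask marks which key positions may be attended to):

```python
import torch
import torch.nn.functional as F

q = torch.randn(1, 8, 16, 64)  # (batch, heads, query_len, head_dim)
k = torch.randn(1, 8, 16, 64)
v = torch.randn(1, 8, 16, 64)

# Boolean mask broadcastable to (batch, heads, query_len, key_len);
# True = attend, False = ignore (e.g. padding tokens).
mask = torch.ones(1, 1, 16, 16, dtype=torch.bool)
mask[..., 8:] = False

out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False)
print(out.shape)  # torch.Size([1, 8, 16, 64])
```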

src/diffusers/pipelines/pipeline_utils.py
Lines changed: 5 additions & 3 deletions

```diff
@@ -1610,7 +1610,7 @@ def _get_signature_keys(cls, obj):
                 expected_modules.add(name)
                 optional_parameters.remove(name)

-        return expected_modules, optional_parameters
+        return sorted(expected_modules), sorted(optional_parameters)

     @classmethod
     def _get_signature_types(cls):
@@ -1652,10 +1652,12 @@ def components(self) -> Dict[str, Any]:
             k: getattr(self, k) for k in self.config.keys() if not k.startswith("_") and k not in optional_parameters
         }

-        if set(components.keys()) != expected_modules:
+        actual = sorted(set(components.keys()))
+        expected = sorted(expected_modules)
+        if actual != expected:
             raise ValueError(
                 f"{self} has been incorrectly initialized or {self.__class__} is incorrectly implemented. Expected"
-                f" {expected_modules} to be defined, but {components.keys()} are defined."
+                f" {expected} to be defined, but {actual} are defined."
             )

         return components
```
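Note: the motivation is determinism. Python sets compare equal regardless of order but iterate and print in arbitrary order, so the old error message (and the return order of `_get_signature_keys`) could differ between runs; and since the helper now returns sorted lists rather than sets, the comparison in `components` must be rewritten to compare sorted lists as well. In miniature (illustrative names only):

```python
expected_modules = sorted({"vae", "unet", "text_encoder"})  # now a list, not a set
components = {"unet": object(), "text_encoder": object()}

actual = sorted(set(components.keys()))
if actual != expected_modules:
    # Stable ordering -> stable, diff-able error message.
    print(f"Expected {expected_modules} to be defined, but {actual} are defined.")
```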

tests/pipelines/test_pipeline_utils.py
Lines changed: 102 additions & 1 deletion

```diff
@@ -19,7 +19,7 @@
     UNet2DConditionModel,
 )
 from diffusers.pipelines.pipeline_loading_utils import is_safetensors_compatible, variant_compatible_siblings
-from diffusers.utils.testing_utils import torch_device
+from diffusers.utils.testing_utils import require_torch_gpu, torch_device


 class IsSafetensorsCompatibleTests(unittest.TestCase):
@@ -826,3 +826,104 @@ def test_video_to_video(self):
         with io.StringIO() as stderr, contextlib.redirect_stderr(stderr):
             _ = pipe(**inputs)
         self.assertTrue(stderr.getvalue() == "", "Progress bar should be disabled")
+
+
+@require_torch_gpu
+class PipelineDeviceAndDtypeStabilityTests(unittest.TestCase):
+    expected_pipe_device = torch.device("cuda:0")
+    expected_pipe_dtype = torch.float64
+
+    def get_dummy_components_image_generation(self):
+        cross_attention_dim = 8
+
+        torch.manual_seed(0)
+        unet = UNet2DConditionModel(
+            block_out_channels=(4, 8),
+            layers_per_block=1,
+            sample_size=32,
+            in_channels=4,
+            out_channels=4,
+            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
+            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
+            cross_attention_dim=cross_attention_dim,
+            norm_num_groups=2,
+        )
+        scheduler = DDIMScheduler(
+            beta_start=0.00085,
+            beta_end=0.012,
+            beta_schedule="scaled_linear",
+            clip_sample=False,
+            set_alpha_to_one=False,
+        )
+        torch.manual_seed(0)
+        vae = AutoencoderKL(
+            block_out_channels=[4, 8],
+            in_channels=3,
+            out_channels=3,
+            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
+            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+            latent_channels=4,
+            norm_num_groups=2,
+        )
+        torch.manual_seed(0)
+        text_encoder_config = CLIPTextConfig(
+            bos_token_id=0,
+            eos_token_id=2,
+            hidden_size=cross_attention_dim,
+            intermediate_size=16,
+            layer_norm_eps=1e-05,
+            num_attention_heads=2,
+            num_hidden_layers=2,
+            pad_token_id=1,
+            vocab_size=1000,
+        )
+        text_encoder = CLIPTextModel(text_encoder_config)
+        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+
+        components = {
+            "unet": unet,
+            "scheduler": scheduler,
+            "vae": vae,
+            "text_encoder": text_encoder,
+            "tokenizer": tokenizer,
+            "safety_checker": None,
+            "feature_extractor": None,
+            "image_encoder": None,
+        }
+        return components
+
+    def test_deterministic_device(self):
+        components = self.get_dummy_components_image_generation()
+
+        pipe = StableDiffusionPipeline(**components)
+        pipe.to(device=torch_device, dtype=torch.float32)
+
+        pipe.unet.to(device="cpu")
+        pipe.vae.to(device="cuda")
+        pipe.text_encoder.to(device="cuda:0")
+
+        pipe_device = pipe.device
+
+        self.assertEqual(
+            self.expected_pipe_device,
+            pipe_device,
+            f"Wrong expected device. Expected {self.expected_pipe_device}. Got {pipe_device}.",
+        )
+
+    def test_deterministic_dtype(self):
+        components = self.get_dummy_components_image_generation()
+
+        pipe = StableDiffusionPipeline(**components)
+        pipe.to(device=torch_device, dtype=torch.float32)
+
+        pipe.unet.to(dtype=torch.float16)
+        pipe.vae.to(dtype=torch.float32)
+        pipe.text_encoder.to(dtype=torch.float64)
+
+        pipe_dtype = pipe.dtype
+
+        self.assertEqual(
+            self.expected_pipe_dtype,
+            pipe_dtype,
+            f"Wrong expected dtype. Expected {self.expected_pipe_dtype}. Got {pipe_dtype}.",
+        )
```
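Note: the assertions encode the resolution rule the pipeline-level properties are expected to follow when submodules disagree: an accelerator placement wins over cpu, and the widest floating-point dtype wins. A hypothetical reduction matching those expectations (illustrative only, not the pipeline's actual implementation):

```python
import torch

def resolve_device(devices):
    # Any non-CPU placement takes precedence over cpu (cuda:0 in the test).
    non_cpu = [d for d in devices if d.type != "cpu"]
    return non_cpu[0] if non_cpu else torch.device("cpu")

def resolve_dtype(dtypes):
    # Widest floating-point type wins: float64 > float32 > float16.
    return max(dtypes, key=lambda d: torch.finfo(d).bits)

print(resolve_device([torch.device("cpu"), torch.device("cuda:0")]))  # cuda:0
print(resolve_dtype([torch.float16, torch.float32, torch.float64]))   # torch.float64
```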
