Skip to content

Commit 81f4785

Browse files
authored
Merge branch 'main' into add-trtquant-backend
2 parents d66709b + 0ff1aa9 commit 81f4785

File tree

10 files changed

+1188
-26
lines changed

10 files changed

+1188
-26
lines changed

docs/source/en/api/pipelines/qwenimage.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,12 @@ The `guidance_scale` parameter in the pipeline is there to support future guidan
120120
- all
121121
- __call__
122122

123+
## QwenImageEditInpaintPipeline
124+
125+
[[autodoc]] QwenImageEditInpaintPipeline
126+
- all
127+
- __call__
128+
123129
## QwenImageControlNetPipeline
124130
- all
125131
- __call__

examples/dreambooth/train_dreambooth_lora_flux_kontext.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1270,6 +1270,7 @@ def main(args):
12701270
subfolder="transformer",
12711271
revision=args.revision,
12721272
variant=args.variant,
1273+
torch_dtype=torch_dtype,
12731274
)
12741275
pipeline = FluxKontextPipeline.from_pretrained(
12751276
args.pretrained_model_name_or_path,
@@ -1292,7 +1293,8 @@ def main(args):
12921293
for example in tqdm(
12931294
sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
12941295
):
1295-
images = pipeline(example["prompt"]).images
1296+
with torch.autocast(device_type=accelerator.device.type, dtype=torch_dtype):
1297+
images = pipeline(prompt=example["prompt"]).images
12961298

12971299
for i, image in enumerate(images):
12981300
hash_image = insecure_hashlib.sha1(image.tobytes()).hexdigest()
@@ -1899,6 +1901,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
18991901
device=accelerator.device,
19001902
prompt=args.instance_prompt,
19011903
)
1904+
else:
1905+
prompt_embeds, pooled_prompt_embeds, text_ids = compute_text_embeddings(
1906+
prompts, text_encoders, tokenizers
1907+
)
19021908

19031909
# Convert images to latent space
19041910
if args.cache_latents:

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -507,6 +507,7 @@
507507
"PixArtSigmaPAGPipeline",
508508
"PixArtSigmaPipeline",
509509
"QwenImageControlNetPipeline",
510+
"QwenImageEditInpaintPipeline",
510511
"QwenImageEditPipeline",
511512
"QwenImageImg2ImgPipeline",
512513
"QwenImageInpaintPipeline",
@@ -1155,6 +1156,7 @@
11551156
PixArtSigmaPAGPipeline,
11561157
PixArtSigmaPipeline,
11571158
QwenImageControlNetPipeline,
1159+
QwenImageEditInpaintPipeline,
11581160
QwenImageEditPipeline,
11591161
QwenImageImg2ImgPipeline,
11601162
QwenImageInpaintPipeline,

src/diffusers/loaders/lora_conversion_utils.py

Lines changed: 40 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2129,6 +2129,10 @@ def _convert_non_diffusers_ltxv_lora_to_diffusers(state_dict, non_diffusers_pref
21292129

21302130

21312131
def _convert_non_diffusers_qwen_lora_to_diffusers(state_dict):
2132+
has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict)
2133+
if has_diffusion_model:
2134+
state_dict = {k.removeprefix("diffusion_model."): v for k, v in state_dict.items()}
2135+
21322136
has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
21332137
if has_lora_unet:
21342138
state_dict = {k.removeprefix("lora_unet_"): v for k, v in state_dict.items()}
@@ -2201,29 +2205,44 @@ def convert_key(key: str) -> str:
22012205
all_keys = list(state_dict.keys())
22022206
down_key = ".lora_down.weight"
22032207
up_key = ".lora_up.weight"
2208+
a_key = ".lora_A.weight"
2209+
b_key = ".lora_B.weight"
22042210

2205-
def get_alpha_scales(down_weight, alpha_key):
2206-
rank = down_weight.shape[0]
2207-
alpha = state_dict.pop(alpha_key).item()
2208-
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
2209-
scale_down = scale
2210-
scale_up = 1.0
2211-
while scale_down * 2 < scale_up:
2212-
scale_down *= 2
2213-
scale_up /= 2
2214-
return scale_down, scale_up
2211+
has_non_diffusers_lora_id = any(down_key in k or up_key in k for k in all_keys)
2212+
has_diffusers_lora_id = any(a_key in k or b_key in k for k in all_keys)
22152213

2216-
for k in all_keys:
2217-
if k.endswith(down_key):
2218-
diffusers_down_key = k.replace(down_key, ".lora_A.weight")
2219-
diffusers_up_key = k.replace(down_key, up_key).replace(up_key, ".lora_B.weight")
2220-
alpha_key = k.replace(down_key, ".alpha")
2221-
2222-
down_weight = state_dict.pop(k)
2223-
up_weight = state_dict.pop(k.replace(down_key, up_key))
2224-
scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
2225-
converted_state_dict[diffusers_down_key] = down_weight * scale_down
2226-
converted_state_dict[diffusers_up_key] = up_weight * scale_up
2214+
if has_non_diffusers_lora_id:
2215+
2216+
def get_alpha_scales(down_weight, alpha_key):
2217+
rank = down_weight.shape[0]
2218+
alpha = state_dict.pop(alpha_key).item()
2219+
scale = alpha / rank # LoRA is scaled by 'alpha / rank' in forward pass, so we need to scale it back here
2220+
scale_down = scale
2221+
scale_up = 1.0
2222+
while scale_down * 2 < scale_up:
2223+
scale_down *= 2
2224+
scale_up /= 2
2225+
return scale_down, scale_up
2226+
2227+
for k in all_keys:
2228+
if k.endswith(down_key):
2229+
diffusers_down_key = k.replace(down_key, ".lora_A.weight")
2230+
diffusers_up_key = k.replace(down_key, up_key).replace(up_key, ".lora_B.weight")
2231+
alpha_key = k.replace(down_key, ".alpha")
2232+
2233+
down_weight = state_dict.pop(k)
2234+
up_weight = state_dict.pop(k.replace(down_key, up_key))
2235+
scale_down, scale_up = get_alpha_scales(down_weight, alpha_key)
2236+
converted_state_dict[diffusers_down_key] = down_weight * scale_down
2237+
converted_state_dict[diffusers_up_key] = up_weight * scale_up
2238+
2239+
# Already in diffusers format (lora_A/lora_B), just pop
2240+
elif has_diffusers_lora_id:
2241+
for k in all_keys:
2242+
if a_key in k or b_key in k:
2243+
converted_state_dict[k] = state_dict.pop(k)
2244+
elif ".alpha" in k:
2245+
state_dict.pop(k)
22272246

22282247
if len(state_dict) > 0:
22292248
raise ValueError(f"`state_dict` should be empty at this point but has {state_dict.keys()=}")

src/diffusers/loaders/lora_pipeline.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6684,7 +6684,8 @@ def lora_state_dict(
66846684

66856685
has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict)
66866686
has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
6687-
if has_alphas_in_sd or has_lora_unet:
6687+
has_diffusion_model = any(k.startswith("diffusion_model.") for k in state_dict)
6688+
if has_alphas_in_sd or has_lora_unet or has_diffusion_model:
66886689
state_dict = _convert_non_diffusers_qwen_lora_to_diffusers(state_dict)
66896690

66906691
out = (state_dict, metadata) if return_lora_metadata else state_dict

src/diffusers/models/attention_dispatch.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -955,12 +955,13 @@ def _native_npu_attention(
955955
dropout_p: float = 0.0,
956956
scale: Optional[float] = None,
957957
) -> torch.Tensor:
958-
return npu_fusion_attention(
958+
query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
959+
out = npu_fusion_attention(
959960
query,
960961
key,
961962
value,
962-
query.size(2), # num_heads
963-
input_layout="BSND",
963+
query.size(1), # num_heads
964+
input_layout="BNSD",
964965
pse=None,
965966
scale=1.0 / math.sqrt(query.shape[-1]) if scale is None else scale,
966967
pre_tockens=65536,
@@ -969,6 +970,8 @@ def _native_npu_attention(
969970
sync=False,
970971
inner_precise=0,
971972
)[0]
973+
out = out.transpose(1, 2).contiguous()
974+
return out
972975

973976

974977
# Reference: https://github.com/pytorch/xla/blob/06c5533de6588f6b90aa1655d9850bcf733b90b4/torch_xla/experimental/custom_kernel.py#L853

src/diffusers/pipelines/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,7 @@
393393
"QwenImageImg2ImgPipeline",
394394
"QwenImageInpaintPipeline",
395395
"QwenImageEditPipeline",
396+
"QwenImageEditInpaintPipeline",
396397
"QwenImageControlNetPipeline",
397398
]
398399
try:
@@ -714,6 +715,7 @@
714715
from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
715716
from .qwenimage import (
716717
QwenImageControlNetPipeline,
718+
QwenImageEditInpaintPipeline,
717719
QwenImageEditPipeline,
718720
QwenImageImg2ImgPipeline,
719721
QwenImageInpaintPipeline,

src/diffusers/pipelines/qwenimage/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
_import_structure["pipeline_qwenimage"] = ["QwenImagePipeline"]
2727
_import_structure["pipeline_qwenimage_controlnet"] = ["QwenImageControlNetPipeline"]
2828
_import_structure["pipeline_qwenimage_edit"] = ["QwenImageEditPipeline"]
29+
_import_structure["pipeline_qwenimage_edit_inpaint"] = ["QwenImageEditInpaintPipeline"]
2930
_import_structure["pipeline_qwenimage_img2img"] = ["QwenImageImg2ImgPipeline"]
3031
_import_structure["pipeline_qwenimage_inpaint"] = ["QwenImageInpaintPipeline"]
3132

@@ -39,6 +40,7 @@
3940
from .pipeline_qwenimage import QwenImagePipeline
4041
from .pipeline_qwenimage_controlnet import QwenImageControlNetPipeline
4142
from .pipeline_qwenimage_edit import QwenImageEditPipeline
43+
from .pipeline_qwenimage_edit_inpaint import QwenImageEditInpaintPipeline
4244
from .pipeline_qwenimage_img2img import QwenImageImg2ImgPipeline
4345
from .pipeline_qwenimage_inpaint import QwenImageInpaintPipeline
4446
else:

0 commit comments

Comments
 (0)