
Commit 7068529

Merge branch 'main' into tests-encode-prompt

2 parents b1c9666 + f8b54cf

23 files changed: +1065 -528 lines

.github/workflows/pr_style_bot.yml

Lines changed: 127 additions & 0 deletions
@@ -0,0 +1,127 @@
+name: PR Style Bot
+
+on:
+  issue_comment:
+    types: [created]
+
+permissions:
+  contents: write
+  pull-requests: write
+
+jobs:
+  run-style-bot:
+    if: >
+      contains(github.event.comment.body, '@bot /style') &&
+      github.event.issue.pull_request != null
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: Extract PR details
+        id: pr_info
+        uses: actions/github-script@v6
+        with:
+          script: |
+            const prNumber = context.payload.issue.number;
+            const { data: pr } = await github.rest.pulls.get({
+              owner: context.repo.owner,
+              repo: context.repo.repo,
+              pull_number: prNumber
+            });
+
+            // We capture both the branch ref and the "full_name" of the head repo
+            // so that we can check out the correct repository & branch (including forks).
+            core.setOutput("prNumber", prNumber);
+            core.setOutput("headRef", pr.head.ref);
+            core.setOutput("headRepoFullName", pr.head.repo.full_name);
+
+      - name: Check out PR branch
+        uses: actions/checkout@v3
+        env:
+          HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
+          HEADREF: ${{ steps.pr_info.outputs.headRef }}
+        with:
+          # Instead of checking out the base repo, use the contributor's repo name
+          repository: ${{ env.HEADREPOFULLNAME }}
+          ref: ${{ env.HEADREF }}
+          # You may need fetch-depth: 0 for being able to push
+          fetch-depth: 0
+          token: ${{ secrets.GITHUB_TOKEN }}
+
+      - name: Debug
+        env:
+          HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
+          HEADREF: ${{ steps.pr_info.outputs.headRef }}
+          PRNUMBER: ${{ steps.pr_info.outputs.prNumber }}
+        run: |
+          echo "PR number: ${{ env.PRNUMBER }}"
+          echo "Head Ref: ${{ env.HEADREF }}"
+          echo "Head Repo Full Name: ${{ env.HEADREPOFULLNAME }}"
+
+      - name: Set up Python
+        uses: actions/setup-python@v4
+
+      - name: Install dependencies
+        run: |
+          pip install .[quality]
+
+      - name: Download Makefile from main branch
+        run: |
+          curl -o main_Makefile https://raw.githubusercontent.com/huggingface/diffusers/main/Makefile
+
+      - name: Compare Makefiles
+        run: |
+          if ! diff -q main_Makefile Makefile; then
+            echo "Error: The Makefile has changed. Please ensure it matches the main branch."
+            exit 1
+          fi
+          echo "No changes in Makefile. Proceeding..."
+          rm -rf main_Makefile
+
+      - name: Run make style and make quality
+        run: |
+          make style && make quality
+
+      - name: Commit and push changes
+        id: commit_and_push
+        env:
+          HEADREPOFULLNAME: ${{ steps.pr_info.outputs.headRepoFullName }}
+          HEADREF: ${{ steps.pr_info.outputs.headRef }}
+          PRNUMBER: ${{ steps.pr_info.outputs.prNumber }}
+          GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
+        run: |
+          echo "HEADREPOFULLNAME: ${{ env.HEADREPOFULLNAME }}, HEADREF: ${{ env.HEADREF }}"
+          # Configure git with the Actions bot user
+          git config user.name "github-actions[bot]"
+          git config user.email "github-actions[bot]@users.noreply.github.com"
+
+          # Make sure your 'origin' remote is set to the contributor's fork
+          git remote set-url origin "https://x-access-token:${GITHUB_TOKEN}@github.com/${{ env.HEADREPOFULLNAME }}.git"
+
+          # If there are changes after running style/quality, commit them
+          if [ -n "$(git status --porcelain)" ]; then
+            git add .
+            git commit -m "Apply style fixes"
+            # Push to the original contributor's forked branch
+            git push origin HEAD:${{ env.HEADREF }}
+            echo "changes_pushed=true" >> $GITHUB_OUTPUT
+          else
+            echo "No changes to commit."
+            echo "changes_pushed=false" >> $GITHUB_OUTPUT
+          fi
+
+      - name: Comment on PR with workflow run link
+        if: steps.commit_and_push.outputs.changes_pushed == 'true'
+        uses: actions/github-script@v6
+        with:
+          script: |
+            const prNumber = parseInt(process.env.prNumber, 10);
+            const runUrl = `${process.env.GITHUB_SERVER_URL}/${process.env.GITHUB_REPOSITORY}/actions/runs/${process.env.GITHUB_RUN_ID}`
+
+            await github.rest.issues.createComment({
+              owner: context.repo.owner,
+              repo: context.repo.repo,
+              issue_number: prNumber,
+              body: `Style fixes have been applied. [View the workflow run here](${runUrl}).`
+            });
+        env:
+          prNumber: ${{ steps.pr_info.outputs.prNumber }}
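For reference, nothing beyond the magic comment is needed to trigger this workflow. A minimal Python sketch of posting it through the GitHub REST API (the repository and PR number are placeholders, and a token with permission to comment is assumed in GITHUB_TOKEN):

import os

import requests

def trigger_style_bot(repo: str, pr_number: int) -> None:
    # PR comments are posted through the issues endpoint, which is why the
    # workflow listens for `issue_comment` events rather than PR events.
    url = f"https://api.github.com/repos/{repo}/issues/{pr_number}/comments"
    response = requests.post(
        url,
        headers={"Authorization": f"Bearer {os.environ['GITHUB_TOKEN']}"},
        json={"body": "@bot /style"},
    )
    response.raise_for_status()

trigger_style_bot("huggingface/diffusers", 1234)  # placeholder PR number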

.github/workflows/pr_tests.yml

Lines changed: 2 additions & 2 deletions
@@ -2,8 +2,8 @@ name: Fast tests for PRs
 
 on:
   pull_request:
-    branches:
-      - main
+    branches: [main]
+    types: [synchronize]
   paths:
     - "src/diffusers/**.py"
     - "benchmarks/**.py"

src/diffusers/loaders/lora_base.py

Lines changed: 13 additions & 7 deletions
@@ -661,8 +661,20 @@ def set_adapters(
         adapter_names: Union[List[str], str],
         adapter_weights: Optional[Union[float, Dict, List[float], List[Dict]]] = None,
     ):
-        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
+        if isinstance(adapter_weights, dict):
+            components_passed = set(adapter_weights.keys())
+            lora_components = set(self._lora_loadable_modules)
+
+            invalid_components = sorted(components_passed - lora_components)
+            if invalid_components:
+                logger.warning(
+                    f"The following components in `adapter_weights` are not part of the pipeline: {invalid_components}. "
+                    f"Available components that are LoRA-compatible: {self._lora_loadable_modules}. So, weights belonging "
+                    "to the invalid components will be removed and ignored."
+                )
+                adapter_weights = {k: v for k, v in adapter_weights.items() if k not in invalid_components}
 
+        adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
         adapter_weights = copy.deepcopy(adapter_weights)
 
         # Expand weights into a list, one entry per adapter
@@ -697,12 +709,6 @@ def set_adapters(
         for adapter_name, weights in zip(adapter_names, adapter_weights):
             if isinstance(weights, dict):
                 component_adapter_weights = weights.pop(component, None)
-
-                if component_adapter_weights is not None and not hasattr(self, component):
-                    logger.warning(
-                        f"Lora weight dict contains {component} weights but will be ignored because pipeline does not have {component}."
-                    )
-
                 if component_adapter_weights is not None and component not in invert_list_adapters[adapter_name]:
                     logger.warning(
                         (
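The net effect of the `set_adapters` change: invalid component keys in a per-component `adapter_weights` dict are now filtered out up front with a single warning, instead of being silently popped component by component. A rough usage sketch (checkpoint id, LoRA path, and weights are illustrative):

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.load_lora_weights("path/to/lora", adapter_name="my_adapter")

# "unet" and "text_encoder" are LoRA-loadable components of this pipeline;
# "vae" is not, so its entry is dropped with the new warning.
pipe.set_adapters(
    "my_adapter",
    adapter_weights={"unet": 0.8, "text_encoder": 0.5, "vae": 1.0},
)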

src/diffusers/loaders/single_file.py

Lines changed: 2 additions & 1 deletion
@@ -19,6 +19,7 @@
 from huggingface_hub import snapshot_download
 from huggingface_hub.utils import LocalEntryNotFoundError, validate_hf_hub_args
 from packaging import version
+from typing_extensions import Self
 
 from ..utils import deprecate, is_transformers_available, logging
 from .single_file_utils import (
@@ -269,7 +270,7 @@ class FromSingleFileMixin:
 
     @classmethod
     @validate_hf_hub_args
-    def from_single_file(cls, pretrained_model_link_or_path, **kwargs):
+    def from_single_file(cls, pretrained_model_link_or_path, **kwargs) -> Self:
         r"""
         Instantiate a [`DiffusionPipeline`] from pretrained pipeline weights saved in the `.ckpt` or `.safetensors`
         format. The pipeline is set in evaluation mode (`model.eval()`) by default.
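The new `-> Self` annotation (PEP 673; imported from `typing_extensions` for older Pythons) lets type checkers bind the return type to the concrete class the method is called on. A minimal sketch of the idea, with toy classes standing in for the real mixin:

from typing_extensions import Self

class FromSingleFileMixin:
    @classmethod
    def from_single_file(cls, pretrained_model_link_or_path: str, **kwargs) -> Self:
        # Simplified stand-in for the real loader: build an instance of `cls`.
        return cls()

class MyPipeline(FromSingleFileMixin):
    def run(self) -> None: ...

# The inferred type is MyPipeline, not FromSingleFileMixin, so subclass-only
# attributes on the result type-check without a cast.
pipe = MyPipeline.from_single_file("model.safetensors")
pipe.run()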

src/diffusers/loaders/single_file_model.py

Lines changed: 16 additions & 7 deletions
@@ -19,6 +19,7 @@
 
 import torch
 from huggingface_hub.utils import validate_hf_hub_args
+from typing_extensions import Self
 
 from ..quantizers import DiffusersAutoQuantizer
 from ..utils import deprecate, is_accelerate_available, logging
@@ -51,7 +52,7 @@
 
 
 if is_accelerate_available():
-    from accelerate import init_empty_weights
+    from accelerate import dispatch_model, init_empty_weights
 
     from ..models.modeling_utils import load_model_dict_into_meta
 
@@ -148,7 +149,7 @@ class FromOriginalModelMixin:
 
     @classmethod
     @validate_hf_hub_args
-    def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = None, **kwargs):
+    def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] = None, **kwargs) -> Self:
         r"""
         Instantiate a model from pretrained weights saved in the original `.ckpt` or `.safetensors` format. The model
         is set in evaluation mode (`model.eval()`) by default.
@@ -365,19 +366,23 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
             keep_in_fp32_modules=keep_in_fp32_modules,
         )
 
+        device_map = None
         if is_accelerate_available():
             param_device = torch.device(device) if device else torch.device("cpu")
-            named_buffers = model.named_buffers()
-            unexpected_keys = load_model_dict_into_meta(
+            empty_state_dict = model.state_dict()
+            unexpected_keys = [
+                param_name for param_name in diffusers_format_checkpoint if param_name not in empty_state_dict
+            ]
+            device_map = {"": param_device}
+            load_model_dict_into_meta(
                 model,
                 diffusers_format_checkpoint,
                 dtype=torch_dtype,
-                device=param_device,
+                device_map=device_map,
                 hf_quantizer=hf_quantizer,
                 keep_in_fp32_modules=keep_in_fp32_modules,
-                named_buffers=named_buffers,
+                unexpected_keys=unexpected_keys,
             )
-
         else:
             _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
 
@@ -399,4 +404,8 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
 
         model.eval()
 
+        if device_map is not None:
+            device_map_kwargs = {"device_map": device_map}
+            dispatch_model(model, **device_map_kwargs)
+
         return model
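Two pieces of the new loading path are worth unpacking: `unexpected_keys` is now computed by diffing the checkpoint against the model's (meta-initialized) state dict rather than returned by `load_model_dict_into_meta`, and the loaded model is then placed with `accelerate.dispatch_model` using the same single-entry `device_map`. A toy sketch of both, with a stand-in module and checkpoint:

import torch
from accelerate import dispatch_model

model = torch.nn.Linear(4, 4)
checkpoint = {
    "weight": torch.eye(4),
    "bias": torch.zeros(4),
    "extra.scale": torch.ones(1),  # no matching slot in the model
}

# Unexpected keys are checkpoint entries the model's state dict lacks.
empty_state_dict = model.state_dict()
unexpected_keys = [k for k in checkpoint if k not in empty_state_dict]
print(unexpected_keys)  # ['extra.scale']

# An empty-string key maps the whole module tree onto one device.
device_map = {"": torch.device("cpu")}
dispatch_model(model, device_map=device_map)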

src/diffusers/loaders/single_file_utils.py

Lines changed: 3 additions & 21 deletions
@@ -1593,18 +1593,9 @@ def create_diffusers_clip_model_from_ldm(
         raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")
 
     if is_accelerate_available():
-        unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
+        load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
     else:
-        _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
-
-    if model._keys_to_ignore_on_load_unexpected is not None:
-        for pat in model._keys_to_ignore_on_load_unexpected:
-            unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
-
-    if len(unexpected_keys) > 0:
-        logger.warning(
-            f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
-        )
+        model.load_state_dict(diffusers_format_checkpoint, strict=False)
 
     if torch_dtype is not None:
         model.to(torch_dtype)
@@ -2061,16 +2052,7 @@ def create_diffusers_t5_model_from_checkpoint(
     diffusers_format_checkpoint = convert_sd3_t5_checkpoint_to_diffusers(checkpoint)
 
     if is_accelerate_available():
-        unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
-        if model._keys_to_ignore_on_load_unexpected is not None:
-            for pat in model._keys_to_ignore_on_load_unexpected:
-                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
-
-        if len(unexpected_keys) > 0:
-            logger.warning(
-                f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
-            )
-
+        load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
     else:
         model.load_state_dict(diffusers_format_checkpoint)
 
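For context on the simplification above: PyTorch's `load_state_dict(..., strict=False)` reports mismatches in its return value instead of raising, so the call still succeeds once the surrounding `unexpected_keys` warning plumbing is removed. A small demonstration with a toy module:

import torch

model = torch.nn.Linear(2, 2)
state = {"weight": torch.zeros(2, 2), "unknown.key": torch.ones(1)}

# strict=False returns a named tuple of mismatches instead of raising.
result = model.load_state_dict(state, strict=False)
print(result.missing_keys)     # ['bias']
print(result.unexpected_keys)  # ['unknown.key']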
