
Commit 64814c9

Merge branch 'main' into flux-control-lora-training-script
2 parents: 02016ca + 25f3e91

File tree

5 files changed: 278 additions & 136 deletions


.github/workflows/pr_test_peft_backend.yml

Lines changed: 0 additions & 134 deletions
This file was deleted.

.github/workflows/pr_tests.yml

Lines changed: 64 additions & 0 deletions
@@ -234,3 +234,67 @@ jobs:
         with:
           name: pr_${{ matrix.config.report }}_test_reports
           path: reports
+
+  run_lora_tests:
+    needs: [check_code_quality, check_repository_consistency]
+    strategy:
+      fail-fast: false
+
+    name: LoRA tests with PEFT main
+
+    runs-on:
+      group: aws-general-8-plus
+
+    container:
+      image: diffusers/diffusers-pytorch-cpu
+      options: --shm-size "16gb" --ipc host -v /mnt/hf_cache:/mnt/cache/
+
+    defaults:
+      run:
+        shell: bash
+
+    steps:
+      - name: Checkout diffusers
+        uses: actions/checkout@v3
+        with:
+          fetch-depth: 2
+
+      - name: Install dependencies
+        run: |
+          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+          python -m uv pip install -e [quality,test]
+          # TODO (sayakpaul, DN6): revisit `--no-deps`
+          python -m pip install -U peft@git+https://github.com/huggingface/peft.git --no-deps
+          python -m uv pip install -U transformers@git+https://github.com/huggingface/transformers.git --no-deps
+          pip uninstall accelerate -y && python -m uv pip install -U accelerate@git+https://github.com/huggingface/accelerate.git --no-deps
+
+      - name: Environment
+        run: |
+          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+          python utils/print_env.py
+
+      - name: Run fast PyTorch LoRA tests with PEFT
+        run: |
+          python -m venv /opt/venv && export PATH="/opt/venv/bin:$PATH"
+          python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
+            -s -v \
+            --make-reports=tests_peft_main \
+            tests/lora/
+          python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile \
+            -s -v \
+            --make-reports=tests_models_lora_peft_main \
+            tests/models/ -k "lora"
+
+      - name: Failure short reports
+        if: ${{ failure() }}
+        run: |
+          cat reports/tests_lora_failures_short.txt
+          cat reports/tests_models_lora_failures_short.txt
+
+      - name: Test suite reports artifacts
+        if: ${{ always() }}
+        uses: actions/upload-artifact@v4
+        with:
+          name: pr_main_test_reports
+          path: reports
+
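For reference, the test steps above can be approximated locally. A minimal sketch (a hypothetical helper script, not part of the repository), assuming it is run from the diffusers repo root so that the `--make-reports` pytest option provided by the repository's conftest is available:

import subprocess

# Mirror the CI job: install PEFT from main, then run the same two
# LoRA test selections as the "Run fast PyTorch LoRA tests with PEFT" step.
commands = [
    "python -m pip install -U peft@git+https://github.com/huggingface/peft.git --no-deps",
    "python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=tests_peft_main tests/lora/",
    'python -m pytest -n 4 --max-worker-restart=0 --dist=loadfile -s -v --make-reports=tests_models_lora_peft_main tests/models/ -k "lora"',
]
for cmd in commands:
    subprocess.run(cmd, shell=True, check=True)  # check=True stops at the first failing step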

src/diffusers/models/attention_processor.py

Lines changed: 95 additions & 0 deletions
@@ -358,6 +358,14 @@ def set_use_memory_efficient_attention_xformers(
             self.processor,
             (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor),
         )
+        is_joint_processor = hasattr(self, "processor") and isinstance(
+            self.processor,
+            (
+                JointAttnProcessor2_0,
+                XFormersJointAttnProcessor,
+            ),
+        )
+
         if use_memory_efficient_attention_xformers:
             if is_added_kv_processor and is_custom_diffusion:
                 raise NotImplementedError(
@@ -420,6 +428,8 @@ def set_use_memory_efficient_attention_xformers(
                 processor.to(
                     device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype
                 )
+            elif is_joint_processor:
+                processor = XFormersJointAttnProcessor(attention_op=attention_op)
             else:
                 processor = XFormersAttnProcessor(attention_op=attention_op)
         else:
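The net effect of these two hunks: calling enable_xformers_memory_efficient_attention() on a model whose attention blocks use JointAttnProcessor2_0 now swaps in XFormersJointAttnProcessor (defined below) instead of the generic XFormersAttnProcessor. A minimal sketch, assuming the SD3 transformer and an illustrative checkpoint id:

import torch
from diffusers import SD3Transformer2DModel

# The checkpoint id is an assumption for illustration; any MMDiT-style
# joint-attention model goes through the same dispatch.
transformer = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    torch_dtype=torch.float16,
)
transformer.enable_xformers_memory_efficient_attention()

# The joint attention blocks should now report the new processor class.
print({type(p).__name__ for p in transformer.attn_processors.values()})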
@@ -1685,6 +1695,91 @@ def __call__(
         return hidden_states, encoder_hidden_states
 
 
+class XFormersJointAttnProcessor:
+    r"""
+    Processor for implementing memory efficient attention using xFormers.
+
+    Args:
+        attention_op (`Callable`, *optional*, defaults to `None`):
+            The base
+            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
+            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
+            operator.
+    """
+
+    def __init__(self, attention_op: Optional[Callable] = None):
+        self.attention_op = attention_op
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        *args,
+        **kwargs,
+    ) -> torch.FloatTensor:
+        residual = hidden_states
+
+        # `sample` projections.
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        query = attn.head_to_batch_dim(query).contiguous()
+        key = attn.head_to_batch_dim(key).contiguous()
+        value = attn.head_to_batch_dim(value).contiguous()
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # `context` projections.
+        if encoder_hidden_states is not None:
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+            encoder_hidden_states_query_proj = attn.head_to_batch_dim(encoder_hidden_states_query_proj).contiguous()
+            encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj).contiguous()
+            encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj).contiguous()
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
+            key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
+            value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
+
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
+        )
+        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        if encoder_hidden_states is not None:
+            # Split the attention outputs.
+            hidden_states, encoder_hidden_states = (
+                hidden_states[:, : residual.shape[1]],
+                hidden_states[:, residual.shape[1] :],
+            )
+            if not attn.context_pre_only:
+                encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if encoder_hidden_states is not None:
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
+
+
 class AllegroAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
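The new processor can also be installed by hand, bypassing the automatic dispatch above. A minimal sketch, reusing the transformer from the earlier example:

from diffusers.models.attention_processor import XFormersJointAttnProcessor

# set_attn_processor replaces the processor on every attention module at once.
transformer.set_attn_processor(XFormersJointAttnProcessor())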
