
Commit 3ae040c

Merge branch 'main' into benchmarking-overhaul
2 parents 5792608 + d72184e commit 3ae040c

File tree

16 files changed: +486 -76 lines changed

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -180,6 +180,8 @@
     title: Caching
   - local: optimization/memory
     title: Reduce memory usage
+  - local: optimization/pruna
+    title: Pruna
   - local: optimization/xformers
     title: xFormers
   - local: optimization/tome
docs/source/en/optimization/pruna.md

Lines changed: 187 additions & 0 deletions
@@ -0,0 +1,187 @@

# Pruna

[Pruna](https://github.com/PrunaAI/pruna) is a model optimization framework that offers various optimization methods - quantization, pruning, caching, compilation - for accelerating inference and reducing memory usage. A general overview of the optimization methods is shown below.

| Technique | Description | Speed | Memory | Quality |
|--------------|-----------------------------------------------------------------------------------------------|:-----:|:------:|:-------:|
| `batcher` | Groups multiple inputs together to be processed simultaneously, improving computational efficiency and reducing processing time. | | | |
| `cacher` | Stores intermediate results of computations to speed up subsequent operations. | | | |
| `compiler` | Optimises the model with instructions for specific hardware. | | | |
| `distiller` | Trains a smaller, simpler model to mimic a larger, more complex model. | | | |
| `quantizer` | Reduces the precision of weights and activations, lowering memory requirements. | | | |
| `pruner` | Removes less important or redundant connections and neurons, resulting in a sparser, more efficient network. | | | |
| `recoverer` | Restores the performance of a model after compression. | | | |
| `factorizer` | Factorization batches several small matrix multiplications into one large fused operation. | | | |
| `enhancer` | Enhances the model output by applying post-processing algorithms such as denoising or upscaling. | | - | |

✅ (improves), ➖ (approx. the same), ❌ (worsens)

Explore the full range of optimization methods in the [Pruna documentation](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/configure.html#configure-algorithms).
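
Each technique in the table maps to a key on a `SmashConfig`. As a minimal sketch (using only the `torch_compile` compiler that also appears in the full example below), enabling a single algorithm looks like this:

```python
from pruna import SmashConfig

# minimal sketch: enable a single optimization algorithm from the table above
smash_config = SmashConfig()
smash_config["compiler"] = "torch_compile"
```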

## Installation

Install Pruna with the following command.

```bash
pip install pruna
```

## Optimize Diffusers models

A broad range of optimization algorithms is supported for Diffusers models, as shown below.

<div class="flex justify-center">
    <img src="https://huggingface.co/datasets/PrunaAI/documentation-images/resolve/main/diffusers/diffusers_combinations.png" alt="Overview of the supported optimization algorithms for diffusers models">
</div>

The example below optimizes [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) with a combination of factorizer, compiler, and cacher algorithms. This combination accelerates inference by up to 4.2x and cuts peak GPU memory usage from 34.7GB to 28.0GB, all while maintaining virtually the same output quality.

> [!TIP]
> Refer to the [Pruna optimization](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/configure.html) docs to learn more about the optimization techniques used in this example.

<div class="flex justify-center">
    <img src="https://huggingface.co/datasets/PrunaAI/documentation-images/resolve/main/diffusers/flux_combination.png" alt="Optimization techniques used for FLUX.1-dev showing the combination of factorizer, compiler, and cacher algorithms">
</div>

Start by defining a `SmashConfig` with the optimization algorithms to use. To optimize the model, pass the pipeline and the `SmashConfig` to `smash`, then use the returned pipeline as normal for inference.

```python
import torch
from diffusers import FluxPipeline

from pruna import PrunaModel, SmashConfig, smash

# load the model
# Try segmind/Segmind-Vega or black-forest-labs/FLUX.1-schnell with a small GPU memory
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16
).to("cuda")

# define the configuration
smash_config = SmashConfig()
smash_config["factorizer"] = "qkv_diffusers"
smash_config["compiler"] = "torch_compile"
smash_config["torch_compile_target"] = "module_list"
smash_config["cacher"] = "fora"
smash_config["fora_interval"] = 2

# for the best results in terms of speed you can add these configs
# however they will increase your warmup time from 1.5 min to 10 min
# smash_config["torch_compile_mode"] = "max-autotune-no-cudagraphs"
# smash_config["quantizer"] = "torchao"
# smash_config["torchao_quant_type"] = "fp8dq"
# smash_config["torchao_excluded_modules"] = "norm+embedding"

# optimize the model
smashed_pipe = smash(pipe, smash_config)

# run the model
smashed_pipe("a knitted purple prune").images[0]
```

<div class="flex justify-center">
    <img src="https://huggingface.co/datasets/PrunaAI/documentation-images/resolve/main/diffusers/flux_smashed_comparison.png">
</div>

After optimization, we can share and load the optimized model using the Hugging Face Hub.

```python
# save the model
smashed_pipe.save_to_hub("<username>/FLUX.1-dev-smashed")

# load the model
smashed_pipe = PrunaModel.from_hub("<username>/FLUX.1-dev-smashed")
```
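
Once loaded from the Hub, the smashed pipeline is called like the original pipeline. A minimal sketch, assuming the model was pushed under `<username>/FLUX.1-dev-smashed` as above:

```python
from pruna import PrunaModel

# minimal sketch: load the smashed pipeline from the Hub and run inference
smashed_pipe = PrunaModel.from_hub("<username>/FLUX.1-dev-smashed")
smashed_pipe.move_to_device("cuda")
image = smashed_pipe("a knitted purple prune").images[0]
image.save("prune.png")
```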

## Evaluate and benchmark Diffusers models

Pruna provides the [EvaluationAgent](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/evaluate.html) to evaluate the quality of your optimized models.

We can define the metrics we care about, such as total time and throughput, and the dataset to evaluate on, then define a model and pass it to the `EvaluationAgent`.

<hfoptions id="eval">
<hfoption id="optimized model">

We can load and evaluate an optimized model by defining a `Task` with the metrics and datamodule, and passing it to the `EvaluationAgent`.

```python
import torch
from diffusers import FluxPipeline

from pruna import PrunaModel
from pruna.data.pruna_datamodule import PrunaDataModule
from pruna.evaluation.evaluation_agent import EvaluationAgent
from pruna.evaluation.metrics import (
    ThroughputMetric,
    TorchMetricWrapper,
    TotalTimeMetric,
)
from pruna.evaluation.task import Task

# define the device
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

# load the model
# Try PrunaAI/Segmind-Vega-smashed or PrunaAI/FLUX.1-dev-smashed with a small GPU memory
smashed_pipe = PrunaModel.from_hub("PrunaAI/FLUX.1-dev-smashed")

# Define the metrics
metrics = [
    TotalTimeMetric(n_iterations=20, n_warmup_iterations=5),
    ThroughputMetric(n_iterations=20, n_warmup_iterations=5),
    TorchMetricWrapper("clip"),
]

# Define the datamodule
datamodule = PrunaDataModule.from_string("LAION256")
datamodule.limit_datasets(10)

# Define the task and evaluation agent
task = Task(metrics, datamodule=datamodule, device=device)
eval_agent = EvaluationAgent(task)

# Evaluate smashed model and offload it to CPU
smashed_pipe.move_to_device(device)
smashed_pipe_results = eval_agent.evaluate(smashed_pipe)
smashed_pipe.move_to_device("cpu")
```

</hfoption>
<hfoption id="standalone model">
155+
156+
Instead of comparing the optimized model to the base model, you can also evaluate the standalone `diffusers` model. This is useful if you want to evaluate the performance of the model without the optimization. We can do so by using the `PrunaModel` wrapper and run the `EvaluationAgent` on it.
157+
158+
```python
159+
import torch
160+
from diffusers import FluxPipeline
161+
162+
from pruna import PrunaModel
163+
164+
# load the model
165+
# Try PrunaAI/Segmind-Vega-smashed or PrunaAI/FLUX.1-dev-smashed with a small GPU memory
166+
pipe = FluxPipeline.from_pretrained(
167+
"black-forest-labs/FLUX.1-dev",
168+
torch_dtype=torch.bfloat16
169+
).to("cpu")
170+
wrapped_pipe = PrunaModel(model=pipe)
171+
```
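
The wrapped pipeline can then be evaluated with the same `Task` and `EvaluationAgent` workflow shown in the optimized-model tab. A minimal sketch, reusing the metrics and datamodule defined there:

```python
from pruna.data.pruna_datamodule import PrunaDataModule
from pruna.evaluation.evaluation_agent import EvaluationAgent
from pruna.evaluation.metrics import TotalTimeMetric
from pruna.evaluation.task import Task

# minimal sketch: reuse the evaluation workflow from the optimized-model example
metrics = [TotalTimeMetric(n_iterations=20, n_warmup_iterations=5)]
datamodule = PrunaDataModule.from_string("LAION256")
datamodule.limit_datasets(10)

task = Task(metrics, datamodule=datamodule, device="cuda")
eval_agent = EvaluationAgent(task)

# evaluate the base pipeline and offload it to CPU afterwards
wrapped_pipe.move_to_device("cuda")
base_pipe_results = eval_agent.evaluate(wrapped_pipe)
wrapped_pipe.move_to_device("cpu")
```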
</hfoption>
</hfoptions>

Now that you have seen how to optimize and evaluate your models, you can start using Pruna on your own models. The references below include many examples to help you get started.

> [!TIP]
> For more details about benchmarking Flux, check out the [Announcing FLUX-Juiced: The Fastest Image Generation Endpoint (2.6 times faster)!](https://huggingface.co/blog/PrunaAI/flux-fastest-image-generation-endpoint) blog post and the [InferBench](https://huggingface.co/spaces/PrunaAI/InferBench) Space.

## Reference

- [Pruna](https://github.com/pruna-ai/pruna)
- [Pruna optimization](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/configure.html#configure-algorithms)
- [Pruna evaluation](https://docs.pruna.ai/en/stable/docs_pruna/user_manual/evaluate.html)
- [Pruna tutorials](https://docs.pruna.ai/en/stable/docs_pruna/tutorials/index.html)

examples/advanced_diffusion_training/README_flux.md

Lines changed: 18 additions & 0 deletions
@@ -76,6 +76,24 @@ This command will prompt you for a token. Copy-paste yours from your [settings/t

> `pip install wandb`
> Alternatively, you can use other tools / train without reporting by modifying the flag `--report_to="wandb"`.

### LoRA Rank and Alpha

Two key LoRA hyperparameters are LoRA rank and LoRA alpha.

- `--rank`: Defines the dimension of the trainable LoRA matrices. A higher rank means more expressiveness and capacity to learn (and more parameters).
- `--lora_alpha`: A scaling factor for the LoRA's output. The LoRA update is scaled by `lora_alpha / lora_rank`.
- `lora_alpha` vs. `rank`: this ratio dictates the LoRA's effective strength:
  - `lora_alpha == rank`: scaling factor is 1. The LoRA is applied with its learned strength. (e.g., alpha=16, rank=16)
  - `lora_alpha < rank`: scaling factor < 1. Reduces the LoRA's impact. Useful for subtle changes or to prevent overpowering the base model. (e.g., alpha=8, rank=16)
  - `lora_alpha > rank`: scaling factor > 1. Amplifies the LoRA's impact. Allows a lower rank LoRA to have a stronger effect. (e.g., alpha=32, rank=16)

> [!TIP]
> A common starting point is to set `lora_alpha` equal to `rank`.
> Some also set `lora_alpha` to be twice the `rank` (e.g., lora_alpha=32 for lora_rank=16)
> to give the LoRA updates more influence without increasing parameter count.
> If you find your LoRA is "overcooking" or learning too aggressively, consider setting `lora_alpha` to half of `rank`
> (e.g., lora_alpha=8 for rank=16). Experimentation is often key to finding the optimal balance for your use case.
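
To make the `lora_alpha / rank` ratio above concrete, here is a small illustrative snippet (not part of the training script) that prints the effective scaling for a few common settings:

```python
# illustrative only: the effective scaling applied to the LoRA update
rank = 16
for lora_alpha in (8, 16, 32):
    scaling = lora_alpha / rank
    print(f"lora_alpha={lora_alpha}, rank={rank} -> scaling={scaling}")
# lora_alpha=8  -> 0.5 (dampens the LoRA)
# lora_alpha=16 -> 1.0 (learned strength)
# lora_alpha=32 -> 2.0 (amplifies the LoRA)
```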

### Target Modules

When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the Unet that relate the image representations with the prompts that describe them.
More recently, SOTA text-to-image diffusion models replaced the Unet with a diffusion Transformer (DiT). With this change, we may also want to explore

examples/advanced_diffusion_training/test_dreambooth_lora_flux_advanced.py

Lines changed: 45 additions & 0 deletions
@@ -13,13 +13,16 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import json
 import logging
 import os
 import sys
 import tempfile

 import safetensors

+from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY
+

 sys.path.append("..")
 from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402
@@ -281,3 +284,45 @@ def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit_removes_mult
         run_command(self._launch_args + resume_run_args)

         self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
+
+    def test_dreambooth_lora_with_metadata(self):
+        # Use a `lora_alpha` that is different from `rank`.
+        lora_alpha = 8
+        rank = 4
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+                --instance_data_dir {self.instance_data_dir}
+                --instance_prompt {self.instance_prompt}
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --lora_alpha={lora_alpha}
+                --rank={rank}
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                """.split()
+
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
+            self.assertTrue(os.path.isfile(state_dict_file))
+
+            # Check if the metadata was properly serialized.
+            with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
+                metadata = f.metadata() or {}
+
+            metadata.pop("format", None)
+            raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
+            if raw:
+                raw = json.loads(raw)
+
+            loaded_lora_alpha = raw["transformer.lora_alpha"]
+            self.assertTrue(loaded_lora_alpha == lora_alpha)
+            loaded_lora_rank = raw["transformer.r"]
+            self.assertTrue(loaded_lora_rank == rank)

examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py

Lines changed: 18 additions & 3 deletions
@@ -55,6 +55,7 @@
 )
 from diffusers.optimization import get_scheduler
 from diffusers.training_utils import (
+    _collate_lora_metadata,
     _set_state_dict_into_text_encoder,
     cast_training_params,
     compute_density_for_timestep_sampling,
@@ -431,6 +432,13 @@ def parse_args(input_args=None):
         help=("The dimension of the LoRA update matrices."),
     )

+    parser.add_argument(
+        "--lora_alpha",
+        type=int,
+        default=4,
+        help="LoRA alpha to be used for additional scaling.",
+    )
+
     parser.add_argument("--lora_dropout", type=float, default=0.0, help="Dropout probability for LoRA layers")

     parser.add_argument(
@@ -1556,7 +1564,7 @@ def main(args):
     # now we will add new LoRA weights to the attention layers
     transformer_lora_config = LoraConfig(
         r=args.rank,
-        lora_alpha=args.rank,
+        lora_alpha=args.lora_alpha,
         lora_dropout=args.lora_dropout,
         init_lora_weights="gaussian",
         target_modules=target_modules,
@@ -1565,7 +1573,7 @@ def main(args):
     if args.train_text_encoder:
         text_lora_config = LoraConfig(
             r=args.rank,
-            lora_alpha=args.rank,
+            lora_alpha=args.lora_alpha,
             lora_dropout=args.lora_dropout,
             init_lora_weights="gaussian",
             target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
@@ -1582,13 +1590,15 @@ def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
             transformer_lora_layers_to_save = None
             text_encoder_one_lora_layers_to_save = None
-
+            modules_to_save = {}
             for model in models:
                 if isinstance(model, type(unwrap_model(transformer))):
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
+                    modules_to_save["transformer"] = model
                 elif isinstance(model, type(unwrap_model(text_encoder_one))):
                     if args.train_text_encoder:  # when --train_text_encoder_ti we don't save the layers
                         text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
+                        modules_to_save["text_encoder"] = model
                 elif isinstance(model, type(unwrap_model(text_encoder_two))):
                     pass  # when --train_text_encoder_ti and --enable_t5_ti we don't save the layers
                 else:
@@ -1601,6 +1611,7 @@ def save_model_hook(models, weights, output_dir):
                 output_dir,
                 transformer_lora_layers=transformer_lora_layers_to_save,
                 text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
+                **_collate_lora_metadata(modules_to_save),
             )
         if args.train_text_encoder_ti:
             embedding_handler.save_embeddings(f"{args.output_dir}/{Path(args.output_dir).name}_emb.safetensors")
@@ -2359,16 +2370,19 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
     # Save the lora layers
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
+        modules_to_save = {}
         transformer = unwrap_model(transformer)
         if args.upcast_before_saving:
             transformer.to(torch.float32)
         else:
             transformer = transformer.to(weight_dtype)
         transformer_lora_layers = get_peft_model_state_dict(transformer)
+        modules_to_save["transformer"] = transformer

         if args.train_text_encoder:
             text_encoder_one = unwrap_model(text_encoder_one)
             text_encoder_lora_layers = get_peft_model_state_dict(text_encoder_one.to(torch.float32))
+            modules_to_save["text_encoder"] = text_encoder_one
         else:
             text_encoder_lora_layers = None

@@ -2377,6 +2391,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
             save_directory=args.output_dir,
             transformer_lora_layers=transformer_lora_layers,
             text_encoder_lora_layers=text_encoder_lora_layers,
+            **_collate_lora_metadata(modules_to_save),
         )

         if args.train_text_encoder_ti:
