Commit f576dc1

Merge branch 'Add-FA2' of https://github.com/ParagEkbote/diffusers into Add-FA2

2 parents: ee922c3 + 40b24c6

File tree

1 file changed (+35, -0)

main.py

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
import os

# Opt into Hub kernels before diffusers is imported, since the flag is read at import time.
os.environ["DIFFUSERS_ENABLE_HUB_KERNELS"] = "yes"

# Debug: verify the env var is set
print(f"DIFFUSERS_ENABLE_HUB_KERNELS = {os.environ.get('DIFFUSERS_ENABLE_HUB_KERNELS')}")

import torch
from diffusers import FluxPipeline
from diffusers.quantizers import PipelineQuantizationConfig

# Debug: check that diffusers picked up the env var
from diffusers.models.attention_dispatch import DIFFUSERS_ENABLE_HUB_KERNELS
print(f"Diffusers sees DIFFUSERS_ENABLE_HUB_KERNELS = {DIFFUSERS_ENABLE_HUB_KERNELS}")

# Load the pipeline, quantizing the transformer to 4-bit NF4
model_id = "black-forest-labs/FLUX.1-dev"
pipe = FluxPipeline.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    quantization_config=PipelineQuantizationConfig(
        quant_backend="bitsandbytes_4bit",
        quant_kwargs={
            "load_in_4bit": True,
            "bnb_4bit_quant_type": "nf4",
            "bnb_4bit_compute_dtype": torch.bfloat16,
        },
        components_to_quantize=["transformer"],
    ),
).to("cuda")

# Route attention through the flash-attention kernel pulled from the Hub
pipe.transformer.set_attention_backend("_flash_hub")

prompt = "A cat holding a sign that says 'hello world'"
image = pipe(prompt, num_inference_steps=28, guidance_scale=4.0).images[0]
image.save("output.png")
