Skip to content

Commit e2f328b

Browse files
authored
Merge branch 'main' into update-skyreels-v2
2 parents fe3af91 + d032408 commit e2f328b

File tree

2 files changed

+79
-6
lines changed

2 files changed

+79
-6
lines changed

docs/source/en/api/pipelines/flux.md

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,67 @@ if integrity_checker.test_image(image_):
316316
raise ValueError("Your image has been flagged. Choose another prompt/image or try again.")
317317
```
318318

319+
### Kontext Inpainting
320+
`FluxKontextInpaintPipeline` enables image modification within a fixed mask region. It currently supports both text-based conditioning and image-reference conditioning.
321+
<hfoptions id="kontext-inpaint">
322+
<hfoption id="text-only">
323+
324+
325+
```python
326+
import torch
327+
from diffusers import FluxKontextInpaintPipeline
328+
from diffusers.utils import load_image
329+
330+
prompt = "Change the yellow dinosaur to green one"
331+
img_url = (
332+
"https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/dinosaur_input.jpeg?raw=true"
333+
)
334+
mask_url = (
335+
"https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/dinosaur_mask.png?raw=true"
336+
)
337+
338+
source = load_image(img_url)
339+
mask = load_image(mask_url)
340+
341+
pipe = FluxKontextInpaintPipeline.from_pretrained(
342+
"black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
343+
)
344+
pipe.to("cuda")
345+
346+
image = pipe(prompt=prompt, image=source, mask_image=mask, strength=1.0).images[0]
347+
image.save("kontext_inpainting_normal.png")
348+
```
349+
</hfoption>
350+
<hfoption id="image conditioning">
351+
352+
```python
353+
import torch
354+
from diffusers import FluxKontextInpaintPipeline
355+
from diffusers.utils import load_image
356+
357+
pipe = FluxKontextInpaintPipeline.from_pretrained(
358+
"black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
359+
)
360+
pipe.to("cuda")
361+
362+
prompt = "Replace this ball"
363+
img_url = "https://images.pexels.com/photos/39362/the-ball-stadion-football-the-pitch-39362.jpeg?auto=compress&cs=tinysrgb&dpr=1&w=500"
364+
mask_url = "https://github.com/ZenAI-Vietnam/Flux-Kontext-pipelines/blob/main/assets/ball_mask.png?raw=true"
365+
image_reference_url = "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcTah3x6OL_ECMBaZ5ZlJJhNsyC-OSMLWAI-xw&s"
366+
367+
source = load_image(img_url)
368+
mask = load_image(mask_url)
369+
image_reference = load_image(image_reference_url)
370+
371+
mask = pipe.mask_processor.blur(mask, blur_factor=12)
372+
image = pipe(
373+
prompt=prompt, image=source, mask_image=mask, image_reference=image_reference, strength=1.0
374+
).images[0]
375+
image.save("kontext_inpainting_ref.png")
376+
```
377+
</hfoption>
378+
</hfoptions>
379+
319380
## Combining Flux Turbo LoRAs with Flux Control, Fill, and Redux
320381

321382
We can combine Flux Turbo LoRAs with Flux Control and other pipelines like Fill and Redux to enable few-steps' inference. The example below shows how to do that for Flux Control LoRA for depth and turbo LoRA from [`ByteDance/Hyper-SD`](https://hf.co/ByteDance/Hyper-SD).
@@ -646,3 +707,15 @@ image.save("flux-fp8-dev.png")
646707
[[autodoc]] FluxFillPipeline
647708
- all
648709
- __call__
710+
711+
## FluxKontextPipeline
712+
713+
[[autodoc]] FluxKontextPipeline
714+
- all
715+
- __call__
716+
717+
## FluxKontextInpaintPipeline
718+
719+
[[autodoc]] FluxKontextInpaintPipeline
720+
- all
721+
- __call__

tests/pipelines/bria/test_pipeline_bria.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,10 @@
2828
)
2929
from diffusers.pipelines.bria import BriaPipeline
3030
from diffusers.utils.testing_utils import (
31+
backend_empty_cache,
3132
enable_full_determinism,
3233
numpy_cosine_similarity_distance,
33-
require_accelerator,
34-
require_torch_gpu,
34+
require_torch_accelerator,
3535
slow,
3636
torch_device,
3737
)
@@ -149,7 +149,7 @@ def test_image_output_shape(self):
149149
assert (output_height, output_width) == (expected_height, expected_width)
150150

151151
@unittest.skipIf(torch_device not in ["cuda", "xpu"], reason="float16 requires CUDA or XPU")
152-
@require_accelerator
152+
@require_torch_accelerator
153153
def test_save_load_float16(self, expected_max_diff=1e-2):
154154
components = self.get_dummy_components()
155155
for name, module in components.items():
@@ -237,20 +237,20 @@ def test_torch_dtype_dict(self):
237237

238238

239239
@slow
240-
@require_torch_gpu
240+
@require_torch_accelerator
241241
class BriaPipelineSlowTests(unittest.TestCase):
242242
pipeline_class = BriaPipeline
243243
repo_id = "briaai/BRIA-3.2"
244244

245245
def setUp(self):
246246
super().setUp()
247247
gc.collect()
248-
torch.cuda.empty_cache()
248+
backend_empty_cache(torch_device)
249249

250250
def tearDown(self):
251251
super().tearDown()
252252
gc.collect()
253-
torch.cuda.empty_cache()
253+
backend_empty_cache(torch_device)
254254

255255
def get_inputs(self, device, seed=0):
256256
generator = torch.Generator(device="cpu").manual_seed(seed)

0 commit comments

Comments
 (0)