Commit f9a925f

init for black-forest flux control tools, canny/depth/fill (#71)
* init for bfl control
* add missing requirements
* code purify

1 parent f3cf5ff · commit f9a925f

File tree: 17 files changed, +285 −16 lines


README.md

Lines changed: 7 additions & 0 deletions
@@ -77,6 +77,13 @@ If you have any questions or feedback, please scan the QR code below, or send em
 <img src="assets/dingtalk.png" alt="dingtalk" width="400" />
 </div>
 
+## Contributing
+We welcome contributions to DiffSynth-Engine. After installing from source, we recommend that developers set up the development environment with the following command:
+```bash
+pip install -e '.[dev]'
+```
+TODO: Please refer to [CONTRIBUTING.md](./CONTRIBUTING.md) for more details.
+
 ## License
 This project is licensed under the Apache License 2.0. See the LICENSE file for details.

diffsynth_engine/models/flux/flux_dit.py

Lines changed: 5 additions & 1 deletion
@@ -322,6 +322,7 @@ class FluxDiT(PreTrainedModel):
 
     def __init__(
         self,
+        in_channel: int = 64,
         attn_impl: Optional[str] = None,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
@@ -336,7 +337,8 @@ def __init__(
             nn.Linear(3072, 3072, device=device, dtype=dtype),
         )
         self.context_embedder = nn.Linear(4096, 3072, device=device, dtype=dtype)
-        self.x_embedder = nn.Linear(64, 3072, device=device, dtype=dtype)
+        # normal flux has 64 channels, bfl canny and depth have 128 channels, bfl fill has 384 channels, bfl redux has 64 channels
+        self.x_embedder = nn.Linear(in_channel, 3072, device=device, dtype=dtype)
 
         self.blocks = nn.ModuleList(
             [FluxDoubleTransformerBlock(3072, 24, attn_impl=attn_impl, device=device, dtype=dtype) for _ in range(19)]
@@ -430,13 +432,15 @@ def from_state_dict(
         state_dict: Dict[str, torch.Tensor],
         device: str,
         dtype: torch.dtype,
+        in_channel: int = 64,
         attn_impl: Optional[str] = None,
     ):
         with no_init_weights():
             model = torch.nn.utils.skip_init(
                 cls,
                 device=device,
                 dtype=dtype,
+                in_channel=in_channel,
                 attn_impl=attn_impl,
             )
         model = model.requires_grad_(False)  # for loading gguf
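
The channel counts in the new comment line up with FLUX's patchified latent layout. A rough sanity check, assuming (not shown in this diff) a 16-channel FLUX VAE latent and 2×2 patchification; the 8×8 mask packing matches the `rearrange` added in `flux_image.py` below:

```python
# Sanity check for the in_channel values above.
# Assumptions (not in this diff): 16-channel FLUX VAE latent, 2x2 patchify.
LATENT_C = 16
PATCH = 2

normal = LATENT_C * PATCH**2                  # 64: noisy latent only
bfl_control = 2 * LATENT_C * PATCH**2         # 128: noisy latent + canny/depth latent
bfl_fill = (2 * LATENT_C + 8 * 8) * PATCH**2  # 384: + masked-image latent + 8x8-packed mask

assert (normal, bfl_control, bfl_fill) == (64, 128, 384)
```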

diffsynth_engine/pipelines/base.py

Lines changed: 1 addition & 1 deletion
@@ -4,10 +4,10 @@
 from typing import Dict, List, Tuple
 from PIL import Image
 from dataclasses import dataclass
-from diffsynth_engine.utils.loader import load_file
 from diffsynth_engine.utils.offload import enable_sequential_cpu_offload
 from diffsynth_engine.utils.gguf import load_gguf_checkpoint
 from diffsynth_engine.utils import logging
+from diffsynth_engine.utils.loader import load_file
 from diffsynth_engine.utils.platform import empty_cache
 
 logger = logging.get_logger(__name__)

diffsynth_engine/pipelines/flux_image.py

Lines changed: 60 additions & 12 deletions
@@ -1,3 +1,4 @@
+from enum import Enum
 import re
 import os
 import torch
@@ -27,6 +28,8 @@
 from diffsynth_engine.utils.download import fetch_model
 from diffsynth_engine.utils.platform import empty_cache
 
+from einops import rearrange
+
 logger = logging.get_logger(__name__)
 
 
@@ -244,11 +247,25 @@ def accumulate(result, new_item):
 ImageType = Union[Image.Image, torch.Tensor, List[Image.Image], List[torch.Tensor]]
 
 
+class ControlType(Enum):
+    normal = "normal"
+    bfl_control = "bfl_control"
+    bfl_fill = "bfl_fill"
+
+    def get_in_channel(self):
+        if self == ControlType.normal:
+            return 64
+        elif self == ControlType.bfl_control:
+            return 128
+        elif self == ControlType.bfl_fill:
+            return 384
+
+
 @dataclass
 class ControlNetParams:
-    model: nn.Module
     scale: float
     image: ImageType
+    model: Optional[nn.Module] = None
     mask: Optional[ImageType] = None
     control_start: float = 0
     control_end: float = 1
@@ -287,6 +304,7 @@ def __init__(
         vae_tiled: bool = False,
         vae_tile_size: int = 256,
         vae_tile_stride: int = 256,
+        control_type: ControlType = ControlType.normal,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
@@ -312,6 +330,7 @@
         self.batch_cfg = batch_cfg
         self.ip_adapter = None
         self.redux = None
+        self.control_type = control_type
         self.model_names = [
             "text_encoder_1",
             "text_encoder_2",
@@ -324,6 +343,7 @@
     def from_pretrained(
         cls,
         model_path_or_config: str | os.PathLike | FluxModelConfig,
+        control_type: ControlType = ControlType.normal,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
         offload_mode: str | None = None,
@@ -364,7 +384,11 @@ def from_pretrained(
         tokenizer_2 = T5TokenizerFast.from_pretrained(FLUX_TOKENIZER_2_CONF_PATH)
         with LoRAContext():
             dit = FluxDiT.from_state_dict(
-                dit_state_dict, device=init_device, dtype=model_config.dit_dtype, attn_impl=model_config.dit_attn_impl
+                dit_state_dict,
+                device=init_device,
+                dtype=model_config.dit_dtype,
+                in_channel=control_type.get_in_channel(),
+                attn_impl=model_config.dit_attn_impl,
             )
         if load_text_encoder:
             text_encoder_1 = FluxTextEncoder1.from_state_dict(
@@ -386,6 +410,7 @@
             vae_decoder=vae_decoder,
             vae_encoder=vae_encoder,
             load_text_encoder=load_text_encoder,
+            control_type=control_type,
             device=device,
             dtype=dtype,
         )
@@ -535,6 +560,12 @@ def predict_noise(
         current_step: int,
         total_step: int,
     ):
+        if self.control_type != ControlType.normal:
+            controlnet_param = controlnet_params[0]
+            latents = torch.cat((latents, controlnet_param.image * controlnet_param.scale), dim=1)
+            latents = latents.to(self.dtype)
+            controlnet_params = []
+
         double_block_output, single_block_output = self.predict_multicontrolnet(
             latents=latents,
             timestep=timestep,
@@ -547,7 +578,9 @@ def predict_noise(
             current_step=current_step,
             total_step=total_step,
         )
+
         self.load_models_to_device(["dit"])
+
         noise_pred = self.dit(
             hidden_states=latents,
             timestep=timestep,
@@ -600,16 +633,28 @@ def prepare_masked_latent(self, image: Image.Image, mask: Image.Image | None, he
             image = self.preprocess_image(image).to(device=self.device, dtype=self.dtype)
             latent = self.encode_image(image)
         else:
-            image = image.resize((width, height))
-            mask = mask.resize((width, height))
-            image = self.preprocess_image(image).to(device=self.device, dtype=self.dtype)
-            mask = self.preprocess_mask(mask).to(device=self.device, dtype=self.dtype)
-            masked_image = image.clone()
-            masked_image[(mask > 0.5).repeat(1, 3, 1, 1)] = -1
-            latent = self.encode_image(masked_image)
-            mask = torch.nn.functional.interpolate(mask, size=(latent.shape[2], latent.shape[3]))
-            mask = 1 - mask
-            latent = torch.cat([latent, mask], dim=1)
+            if self.control_type == ControlType.normal:
+                image = image.resize((width, height))
+                mask = mask.resize((width, height))
+                image = self.preprocess_image(image).to(device=self.device, dtype=self.dtype)
+                mask = self.preprocess_mask(mask).to(device=self.device, dtype=self.dtype)
+                masked_image = image.clone()
+                masked_image[(mask > 0.5).repeat(1, 3, 1, 1)] = -1
+                latent = self.encode_image(masked_image)
+                mask = torch.nn.functional.interpolate(mask, size=(latent.shape[2], latent.shape[3]))
+                mask = 1 - mask
+                latent = torch.cat([latent, mask], dim=1)
+            elif self.control_type == ControlType.bfl_fill:
+                image = image.resize((width, height))
+                mask = mask.resize((width, height))
+                image = self.preprocess_image(image).to(device=self.device, dtype=self.dtype)
+                mask = self.preprocess_mask(mask).to(device=self.device, dtype=self.dtype)
+                image = image * (1 - mask)
+                image = self.encode_image(image)
+                mask = rearrange(mask, "b 1 (h ph) (w pw) -> b (ph pw) h w", ph=8, pw=8)
+                latent = torch.cat((image, mask), dim=1)
+            else:
+                raise ValueError(f"Unsupported mask latent prepare for controlnet type: {self.control_type}")
         return latent
 
     def prepare_controlnet_params(self, controlnet_params: List[ControlNetParams], h, w):
@@ -706,6 +751,9 @@ def __call__(
         controlnet_params: List[ControlNetParams] | ControlNetParams = [],
         progress_callback: Optional[Callable] = None,  # def progress_callback(current, total, status)
     ):
+        if self.control_type != ControlType.normal:
+            assert controlnet_params and len(controlnet_params) == 1, "bfl_controlnet must have one controlnet"
+
         if input_image is not None:
             width, height = input_image.size
         if not isinstance(controlnet_params, list):
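
Taken together, these changes thread `control_type` from `from_pretrained` down to the DiT's `x_embedder` and route the single BFL condition through latent concatenation rather than a separate ControlNet (hence `model` becoming `Optional`). A hypothetical usage sketch — the pipeline class name `FluxImagePipeline`, the processor import path, the checkpoint path, and the `prompt` keyword are assumptions not shown in this diff, and the control image is assumed to be resized and encoded by `prepare_controlnet_params` before `predict_noise` concatenates it:

```python
from PIL import Image

from diffsynth_engine.pipelines.flux_image import ControlNetParams, ControlType
# hypothetical module path; the new processor file's name is not shown in this diff
from diffsynth_engine.processor.canny import CannyProcessor

# FluxImagePipeline and the checkpoint path are placeholders
pipe = FluxImagePipeline.from_pretrained(
    "path/to/flux-canny-checkpoint.safetensors",
    control_type=ControlType.bfl_control,  # builds the DiT with in_channel=128
)

# BFL control requires exactly one ControlNetParams; `model` stays None
# because the condition is concatenated onto the latents rather than run
# through a separate ControlNet.
edges = CannyProcessor(device="cuda:0")(Image.open("input.jpg"))
result = pipe(
    prompt="a cat, studio lighting",  # assumed keyword
    controlnet_params=[ControlNetParams(scale=1.0, image=edges)],
)
```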

diffsynth_engine/processor/__init__.py

Whitespace-only changes.
Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+import cv2
+import numpy as np
+from PIL import Image
+
+
+class CannyProcessor:
+    def __init__(
+        self,
+        device,
+        low_threshold: int = 100,
+        high_threshold: int = 200,
+    ):
+        self.device = device
+        self.low_threshold = low_threshold
+        self.high_threshold = high_threshold
+
+    def __call__(self, image: Image.Image) -> Image.Image:
+        image = np.array(image.convert("RGB"), dtype=np.uint8)
+        output_image = cv2.Canny(image, self.low_threshold, self.high_threshold)
+        output_image = Image.fromarray(output_image)
+        return output_image
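
A minimal usage sketch for the new processor. Note that `device` is stored but unused here, since `cv2.Canny` runs on the CPU; file paths are placeholders:

```python
from PIL import Image

# assuming the CannyProcessor defined above is importable
processor = CannyProcessor(device="cpu", low_threshold=100, high_threshold=200)

edges = processor(Image.open("input.jpg"))  # single-channel ("L") edge map
edges.convert("RGB").save("canny.png")      # convert if a 3-channel control image is needed
```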
Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
+import numpy as np
+import torch
+import torch.nn.functional as F
+from PIL import Image
+from torchvision.transforms.functional import to_tensor, normalize, resize, to_pil_image
+
+
+from diffsynth_engine.utils.download import fetch_model
+from diffsynth_engine.utils.onnx import OnnxModel
+
+
+MODEL_ID = "muse/depth_anything_detector"
+REVISION = "20240801180053"
+MODEL_NAME = "depth_anything_detector.onnx"
+
+
+class DepthProcessor:
+    def __init__(self, device):
+        self.device = device
+        model_path = fetch_model(model_uri=MODEL_ID, revision=REVISION, path=MODEL_NAME)
+        self.model = OnnxModel(model_path, device=self.device)
+
+    def _image_preprocess(self, image: Image.Image) -> np.ndarray:
+        image = resize(image, (518, 518))
+        image = to_tensor(image)
+        image = normalize(image, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+        image = image.unsqueeze(0).contiguous()
+        return image.numpy()
+
+    def __call__(self, img: Image.Image) -> Image.Image:
+        image = img
+        w, h = image.size
+        image = self._image_preprocess(image)
+        depth = self.model(image)
+        depth = torch.from_numpy(depth)
+        depth: torch.Tensor = F.interpolate(depth[None], (h, w), mode="bilinear", align_corners=False)
+        depth = depth.squeeze(0).squeeze(0)
+        # clamp tensor values to [0, 255] and convert to uint8
+        depth = torch.clamp(depth, 0, 255).byte()
+        # convert to a PIL Image object
+        depth = to_pil_image(depth)
+        return depth
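
The fixed 518×518 resize matches Depth Anything's ViT input resolution (a multiple of its 14-pixel patch size), and the prediction is bilinearly resized back to the source size. A minimal usage sketch with placeholder paths; `fetch_model` downloads the ONNX weights on first use:

```python
from PIL import Image

# assuming the DepthProcessor defined above is importable
processor = DepthProcessor(device="cuda:0")

depth_map = processor(Image.open("input.jpg"))  # grayscale PIL image at the input size
depth_map.save("depth.png")
```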

diffsynth_engine/utils/onnx.py

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+import onnxruntime
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+def to_numpy(tensor):
+    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()
+
+
+class OnnxModel:
+    def __init__(self, model_path: str, device: str = "cuda:0"):
+        self.model_path = model_path
+        if "cuda" in device:
+            self.session = onnxruntime.InferenceSession(
+                model_path, providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
+            )
+        else:
+            self.session = onnxruntime.InferenceSession(model_path, providers=["CPUExecutionProvider"])
+
+    def forward(self, *args, **kwargs):
+        inputs = {}
+        for key, value in kwargs.items():
+            inputs[key] = value
+        for i, arg in enumerate(args):
+            name = self.session.get_inputs()[i].name
+            if name in inputs:
+                raise ValueError(f"the input name [{name}] is duplicated")
+            inputs[name] = arg
+        return self.session.run(None, inputs)[0]
+
+    def __call__(self, *args, **kwargs):
+        return self.forward(*args, **kwargs)
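
`forward` merges keyword inputs with positional arguments bound, in order, to the session's input names, and returns only the first output. A small sketch with a placeholder model file and input shape:

```python
import numpy as np

model = OnnxModel("model.onnx", device="cpu")  # placeholder ONNX file

x = np.zeros((1, 3, 518, 518), dtype=np.float32)
y = model(x)           # positional: bound to the session's first input name
# y = model(images=x)  # keyword form works only if an input is literally named "images"
```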

pyproject.toml

Lines changed: 2 additions & 1 deletion
@@ -29,7 +29,8 @@ dependencies = [
     "torchsde",
     "pillow",
     "imageio[ffmpeg]",
-    "yunchang ; sys_platform == 'linux'"
+    "yunchang ; sys_platform == 'linux'",
+    "onnxruntime"
 ]
 
 [project.optional-dependencies]

tests/common/test_case.py

Lines changed: 2 additions & 1 deletion
@@ -1,5 +1,6 @@
 import unittest
 import os
+import time
 import numpy as np
 import torch
 from pathlib import Path
@@ -79,7 +80,7 @@ def assertImageEqualAndSaveFailed(self, input_image: Image.Image, expect_image_p
             self.assertImageEqual(input_image, expect_image, threshold=threshold)
         except Exception as e:
             name = expect_image_path.split("/")[-1]
-            input_image.save(f"{name}")
+            input_image.save(f"save_{time.time()}_{name}")
             raise e