Skip to content

Commit 924c6c4

Browse files
authored
Dev/flux tool from state dict (#153)
* FluxOutpaintingTool & FluxReplaceByControlTool implement from_state_dict * flux tool add from_pretrained & from_state_dict * fix redux model path
1 parent 95b5958 commit 924c6c4

File tree

4 files changed

+164
-22
lines changed

4 files changed

+164
-22
lines changed

diffsynth_engine/tools/flux_inpainting_tool.py

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,28 @@
1-
from diffsynth_engine import fetch_model, FluxPipelineConfig, FluxControlNet, ControlNetParams, FluxImagePipeline
2-
from typing import List, Tuple, Optional, Callable
1+
from diffsynth_engine import (
2+
fetch_model,
3+
FluxPipelineConfig,
4+
FluxControlNet,
5+
ControlNetParams,
6+
FluxImagePipeline,
7+
FluxStateDicts,
8+
)
9+
from typing import List, Tuple, Optional, Callable, Dict
310
from PIL import Image
411
import torch
512

613

714
class FluxInpaintingTool:
815
def __init__(
916
self,
17+
flux_pipe: FluxImagePipeline,
18+
controlnet: FluxControlNet
19+
):
20+
self.pipe = flux_pipe
21+
self.controlnet = controlnet
22+
23+
@classmethod
24+
def from_pretrained(
25+
cls,
1026
flux_model_path: str,
1127
device: str = "cuda:0",
1228
dtype: torch.dtype = torch.bfloat16,
@@ -18,14 +34,35 @@ def __init__(
1834
device=device,
1935
offload_mode=offload_mode,
2036
)
21-
self.pipe = FluxImagePipeline.from_pretrained(config)
22-
self.controlnet = FluxControlNet.from_pretrained(
37+
flux_pipe = FluxImagePipeline.from_pretrained(config)
38+
controlnet = FluxControlNet.from_pretrained(
2339
fetch_model(
2440
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", path="diffusion_pytorch_model.safetensors"
2541
),
2642
device=device,
2743
dtype=torch.bfloat16,
2844
)
45+
return cls(flux_pipe, controlnet)
46+
47+
@classmethod
48+
def from_state_dict(
49+
cls,
50+
flux_state_dicts: FluxStateDicts,
51+
controlnet_state_dict: Dict[str, torch.Tensor],
52+
device: str = "cuda:0",
53+
dtype: torch.dtype = torch.bfloat16,
54+
offload_mode: Optional[str] = None,
55+
):
56+
config = FluxPipelineConfig(
57+
model_path="",
58+
model_dtype=dtype,
59+
device=device,
60+
offload_mode=offload_mode,
61+
)
62+
flux_pipe = FluxImagePipeline.from_state_dict(flux_state_dicts, config)
63+
controlnet = FluxControlNet.from_state_dict(controlnet_state_dict, device, dtype)
64+
return cls(flux_pipe, controlnet)
65+
2966

3067
def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
3168
self.pipe.load_loras(lora_list, fused, save_original_weight)

diffsynth_engine/tools/flux_outpainting_tool.py

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,28 @@
1-
from diffsynth_engine import fetch_model, FluxPipelineConfig, FluxControlNet, ControlNetParams, FluxImagePipeline
2-
from typing import List, Tuple, Optional, Callable
1+
from diffsynth_engine import (
2+
fetch_model,
3+
FluxPipelineConfig,
4+
FluxControlNet,
5+
ControlNetParams,
6+
FluxImagePipeline,
7+
FluxStateDicts
8+
)
9+
from typing import List, Tuple, Optional, Callable, Dict
310
from PIL import Image
411
import torch
512

613

714
class FluxOutpaintingTool:
815
def __init__(
916
self,
17+
flux_pipe: FluxImagePipeline,
18+
controlnet: FluxControlNet,
19+
):
20+
self.pipe = flux_pipe
21+
self.controlnet = controlnet
22+
23+
@classmethod
24+
def from_pretrained(
25+
cls,
1026
flux_model_path: str,
1127
device: str = "cuda:0",
1228
dtype: torch.dtype = torch.bfloat16,
@@ -18,14 +34,35 @@ def __init__(
1834
device=device,
1935
offload_mode=offload_mode,
2036
)
21-
self.pipe = FluxImagePipeline.from_pretrained(config)
22-
self.controlnet = FluxControlNet.from_pretrained(
37+
flux_pipe = FluxImagePipeline.from_pretrained(config)
38+
controlnet = FluxControlNet.from_pretrained(
2339
fetch_model(
24-
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", path="diffusion_pytorch_model.safetensors"
40+
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
41+
path="diffusion_pytorch_model.safetensors"
2542
),
2643
device=device,
27-
dtype=torch.bfloat16,
44+
dtype=torch.bfloat16
45+
)
46+
return cls(flux_pipe, controlnet)
47+
48+
@classmethod
49+
def from_state_dict(
50+
cls,
51+
flux_state_dicts: FluxStateDicts,
52+
controlnet_state_dict: Dict[str, torch.Tensor],
53+
device: str = "cuda:0",
54+
dtype: torch.dtype = torch.bfloat16,
55+
offload_mode: Optional[str] = None,
56+
):
57+
config = FluxPipelineConfig(
58+
model_path="",
59+
model_dtype=dtype,
60+
device=device,
61+
offload_mode=offload_mode,
2862
)
63+
flux_pipe = FluxImagePipeline.from_state_dict(flux_state_dicts, config)
64+
controlnet = FluxControlNet.from_state_dict(controlnet_state_dict, device, dtype)
65+
return cls(flux_pipe, controlnet)
2966

3067
def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
3168
self.pipe.load_loras(lora_list, fused, save_original_weight)

diffsynth_engine/tools/flux_reference_tool.py

Lines changed: 36 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
FluxIPAdapter,
66
FluxRedux,
77
fetch_model,
8+
FluxStateDicts
89
)
9-
from typing import List, Tuple, Optional
10+
from typing import List, Tuple, Optional, Dict
1011
from PIL import Image
1112
import torch
1213

@@ -18,8 +19,17 @@ class FluxReduxRefTool:
1819

1920
def __init__(
2021
self,
22+
flux_pipe: FluxImagePipeline,
23+
redux: FluxRedux,
24+
):
25+
self.pipe = flux_pipe
26+
self.pipe.load_redux(redux)
27+
28+
@classmethod
29+
def from_pretrained(
30+
cls,
2131
flux_model_path: str,
22-
load_text_encoder=True,
32+
load_text_encoder: bool = True,
2333
device: str = "cuda:0",
2434
dtype: torch.dtype = torch.bfloat16,
2535
offload_mode: Optional[str] = None,
@@ -31,10 +41,31 @@ def __init__(
3141
device=device,
3242
offload_mode=offload_mode,
3343
)
34-
self.pipe: FluxImagePipeline = FluxImagePipeline.from_pretrained(config)
44+
flux_pipe = FluxImagePipeline.from_pretrained(config)
3545
redux_model_path = fetch_model("muse/flux1-redux-dev", path="flux1-redux-dev.safetensors", revision="v1")
36-
flux_redux = FluxRedux.from_pretrained(redux_model_path, device=device)
37-
self.pipe.load_redux(flux_redux)
46+
redux = FluxRedux.from_pretrained(redux_model_path, device=device)
47+
return cls(flux_pipe, redux)
48+
49+
@classmethod
50+
def from_state_dict(
51+
cls,
52+
flux_state_dicts: FluxStateDicts,
53+
redux_state_dict: Dict[str, torch.Tensor],
54+
load_text_encoder: bool = True,
55+
device: str = "cuda:0",
56+
dtype: torch.dtype = torch.bfloat16,
57+
offload_mode: Optional[str] = None,
58+
):
59+
config = FluxPipelineConfig(
60+
model_path="",
61+
model_dtype=dtype,
62+
load_text_encoder=load_text_encoder,
63+
device=device,
64+
offload_mode=offload_mode,
65+
)
66+
flux_pipe = FluxImagePipeline.from_state_dict(flux_state_dicts, config)
67+
redux = FluxRedux.from_state_dict(redux_state_dict, device=device, dtype=dtype)
68+
return cls(flux_pipe, redux)
3869

3970
def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
4071
self.pipe.load_loras(lora_list, fused, save_original_weight)

diffsynth_engine/tools/flux_replace_tool.py

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@
55
FluxImagePipeline,
66
FluxRedux,
77
fetch_model,
8+
FluxStateDicts
89
)
9-
from typing import List, Tuple, Optional, Callable
10+
from typing import List, Tuple, Optional, Callable, Dict
1011
from PIL import Image
1112
import torch
1213

@@ -19,8 +20,19 @@ class FluxReplaceByControlTool:
1920

2021
def __init__(
2122
self,
23+
flux_pipe: FluxImagePipeline,
24+
redux: FluxRedux,
25+
controlnet: FluxControlNet,
26+
):
27+
self.pipe = flux_pipe
28+
self.pipe.load_redux(redux)
29+
self.controlnet = controlnet
30+
31+
@classmethod
32+
def from_pretrained(
33+
cls,
2234
flux_model_path: str,
23-
load_text_encoder=True,
35+
load_text_encoder: bool = True,
2436
device: str = "cuda:0",
2537
dtype: torch.dtype = torch.bfloat16,
2638
offload_mode: Optional[str] = None,
@@ -32,17 +44,42 @@ def __init__(
3244
device=device,
3345
offload_mode=offload_mode,
3446
)
35-
self.pipe: FluxImagePipeline = FluxImagePipeline.from_pretrained(config)
47+
flux_pipe = FluxImagePipeline.from_pretrained(config)
3648
redux_model_path = fetch_model("muse/flux1-redux-dev", path="flux1-redux-dev.safetensors", revision="v1")
37-
flux_redux = FluxRedux.from_pretrained(redux_model_path, device=device)
38-
self.pipe.load_redux(flux_redux)
39-
self.controlnet = FluxControlNet.from_pretrained(
49+
redux = FluxRedux.from_pretrained(redux_model_path, device=device, dtype=dtype)
50+
controlnet = FluxControlNet.from_pretrained(
4051
fetch_model(
41-
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta", path="diffusion_pytorch_model.safetensors"
52+
"alimama-creative/FLUX.1-dev-Controlnet-Inpainting-Beta",
53+
path="diffusion_pytorch_model.safetensors"
4254
),
4355
device=device,
4456
dtype=torch.bfloat16,
4557
)
58+
return cls(flux_pipe, redux, controlnet)
59+
60+
@classmethod
61+
def from_state_dict(
62+
cls,
63+
flux_state_dicts: FluxStateDicts,
64+
redux_state_dict: Dict[str, torch.Tensor],
65+
controlnet_state_dict: Dict[str, torch.Tensor],
66+
load_text_encoder: bool = True,
67+
device: str = "cuda:0",
68+
dtype: torch.dtype = torch.bfloat16,
69+
offload_mode: Optional[str] = None,
70+
):
71+
config = FluxPipelineConfig(
72+
model_path="",
73+
model_dtype=dtype,
74+
load_text_encoder=load_text_encoder,
75+
device=device,
76+
offload_mode=offload_mode,
77+
)
78+
flux_pipe = FluxImagePipeline.from_state_dict(flux_state_dicts, config)
79+
redux = FluxRedux.from_state_dict(redux_state_dict, device=device, dtype=dtype)
80+
controlnet = FluxControlNet.from_state_dict(controlnet_state_dict, device=device, dtype=dtype)
81+
return cls(flux_pipe, redux, controlnet)
82+
4683

4784
def load_loras(self, lora_list: List[Tuple[str, float]], fused: bool = True, save_original_weight: bool = False):
4885
self.pipe.load_loras(lora_list, fused, save_original_weight)

0 commit comments

Comments
 (0)