
Commit 56095ef

fix batch cfg (#93)
* fix batch cfg
* remove use_cfg
1 parent 27dcb5f commit 56095ef

5 files changed: +49 additions, −16 deletions

diffsynth_engine/models/basic/transformer_helper.py

Lines changed: 2 additions & 2 deletions
@@ -15,7 +15,7 @@ def __init__(self, dim: int, eps: float = 1e-6, device: str = "cuda:0", dtype: t
         self.silu = nn.SiLU()
 
     def forward(self, x, emb):
-        shift, scale = self.linear(self.silu(emb)).chunk(2, dim=1)
+        shift, scale = self.linear(self.silu(emb)).unsqueeze(1).chunk(2, dim=1)
         return modulate(self.norm(x), shift, scale)
 
 
@@ -27,7 +27,7 @@ def __init__(self, dim, device: str, dtype: torch.dtype):
         self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6, device=device, dtype=dtype)
 
     def forward(self, x, emb):
-        shift, scale, gate = self.linear(self.silu(emb)).chunk(3, dim=1)
+        shift, scale, gate = self.linear(self.silu(emb)).unsqueeze(1).chunk(3, dim=1)
         return modulate(self.norm(x), shift, scale), gate
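The `.unsqueeze(1)` change inserts a singleton sequence dimension so that per-sample shift/scale values broadcast across all tokens when the conditioning embedding keeps a real batch dimension, as it does under batched CFG. A minimal broadcasting sketch, assuming a `modulate` of the common `x * (1 + scale) + shift` form (not the repository's exact implementation):

import torch

def modulate(x, shift, scale):
    # x: [batch, seq_len, dim]; shift/scale: [batch, 1, dim] broadcast over seq_len
    return x * (1 + scale) + shift

batch, seq_len, dim = 2, 16, 8
x = torch.randn(batch, seq_len, dim)
emb_out = torch.randn(batch, 2 * dim)                  # stands in for linear(silu(emb))
shift, scale = emb_out.unsqueeze(1).chunk(2, dim=-1)   # each [batch, 1, dim]
print(modulate(x, shift, scale).shape)                 # torch.Size([2, 16, 8])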

diffsynth_engine/pipelines/flux_image.py

Lines changed: 23 additions & 9 deletions
@@ -308,7 +308,6 @@ def __init__(
         vae_decoder: FluxVAEDecoder,
         vae_encoder: FluxVAEEncoder,
         load_text_encoder: bool = True,
-        use_cfg: bool = False,
         batch_cfg: bool = False,
         vae_tiled: bool = False,
         vae_tile_size: int = 256,
@@ -336,7 +335,6 @@ def __init__(
         self.vae_decoder = vae_decoder
         self.vae_encoder = vae_encoder
         self.load_text_encoder = load_text_encoder
-        self.use_cfg = use_cfg
         self.batch_cfg = batch_cfg
         self.ip_adapter = None
         self.redux = None
@@ -353,11 +351,15 @@ def __init__(
     def from_pretrained(
         cls,
         model_path_or_config: str | os.PathLike | FluxModelConfig,
+        load_text_encoder: bool = True,
+        batch_cfg: bool = False,
+        vae_tiled: bool = False,
+        vae_tile_size: int = 256,
+        vae_tile_stride: int = 256,
         control_type: ControlType = ControlType.normal,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
         offload_mode: str | None = None,
-        load_text_encoder: bool = True,
         parallelism: int = 1,
         use_cfg_parallel: bool = False,
     ) -> "FluxImagePipeline":
@@ -454,6 +456,10 @@ def from_pretrained(
             vae_decoder=vae_decoder,
             vae_encoder=vae_encoder,
             load_text_encoder=load_text_encoder,
+            batch_cfg=batch_cfg,
+            vae_tiled=vae_tiled,
+            vae_tile_size=vae_tile_size,
+            vae_tile_stride=vae_tile_stride,
             control_type=control_type,
             device=device,
             dtype=dtype,
@@ -530,10 +536,9 @@ def predict_noise_with_cfg(
         controlnet_params: List[ControlNetParams],
         current_step: int,
         total_step: int,
-        use_cfg: bool = False,
         batch_cfg: bool = False,
     ):
-        if cfg_scale <= 1.0 or not use_cfg:
+        if cfg_scale <= 1.0:
             return self.predict_noise(
                 latents,
                 timestep,
@@ -583,6 +588,10 @@ def predict_noise_with_cfg(
            add_text_embeds = torch.cat([positive_add_text_embeds, negative_add_text_embeds], dim=0)
            latents = torch.cat([latents, latents], dim=0)
            timestep = torch.cat([timestep, timestep], dim=0)
+           image_emb = torch.cat([image_emb, image_emb], dim=0) if image_emb is not None else None
+           image_ids = torch.cat([image_ids, image_ids], dim=0)
+           text_ids = torch.cat([text_ids, text_ids], dim=0)
+           guidance = torch.cat([guidance, guidance], dim=0)
            positive_noise_pred, negative_noise_pred = self.predict_noise(
                latents,
                timestep,
@@ -676,8 +685,14 @@ def prepare_latents(
             num_inference_steps, mu=mu, sigma_min=1 / num_inference_steps, sigma_max=1.0
         )
         init_latents = latents.clone()
-        sigmas, timesteps = sigmas.to(device=self.device, dtype=self.dtype), timesteps.to(device=self.device, dtype=self.dtype)
-        init_latents, latents = init_latents.to(device=self.device, dtype=self.dtype), latents.to(device=self.device, dtype=self.dtype)
+        sigmas, timesteps = (
+            sigmas.to(device=self.device, dtype=self.dtype),
+            timesteps.to(device=self.device, dtype=self.dtype),
+        )
+        init_latents, latents = (
+            init_latents.to(device=self.device, dtype=self.dtype),
+            latents.to(device=self.device, dtype=self.dtype),
+        )
         return init_latents, latents, sigmas, timesteps
 
     def prepare_masked_latent(self, image: Image.Image, mask: Image.Image | None, height: int, width: int):
@@ -826,7 +841,7 @@ def __call__(
         # Encode prompts
         self.load_models_to_device(["text_encoder_1", "text_encoder_2"])
         positive_prompt_emb, positive_add_text_embeds = self.encode_prompt(prompt, clip_skip=clip_skip)
-        if self.use_cfg and cfg_scale > 1:
+        if cfg_scale > 1:
             negative_prompt_emb, negative_add_text_embeds = self.encode_prompt(negative_prompt, clip_skip=clip_skip)
         else:
             negative_prompt_emb, negative_add_text_embeds = None, None
@@ -868,7 +883,6 @@ def __call__(
            controlnet_params=controlnet_params,
            current_step=i,
            total_step=len(timesteps),
-           use_cfg=self.use_cfg,
            batch_cfg=self.batch_cfg,
        )
        # Denoise
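The added `torch.cat` calls close a gap in the batched path: when `batch_cfg` is set, every input to the transformer must be duplicated along dim 0 so that one forward pass covers the positive and negative branches; before this fix `image_emb`, `image_ids`, `text_ids`, and `guidance` kept their original batch size and no longer matched the doubled latents. A minimal sketch of the pattern with a stand-in `predict_noise` (names and shapes are illustrative, not the pipeline's exact code):

import torch

def predict_noise(latents, cond):
    # Stand-in for the transformer forward pass.
    return latents * 0.1 + cond.mean(dim=-1, keepdim=True)

def predict_noise_with_batch_cfg(latents, pos_cond, neg_cond, cfg_scale):
    if cfg_scale <= 1.0:
        # CFG effectively disabled: a single conditional pass is enough.
        return predict_noise(latents, pos_cond)
    # Duplicate every input along the batch dimension so one forward pass
    # evaluates the positive and negative branches together.
    latents2 = torch.cat([latents, latents], dim=0)
    cond2 = torch.cat([pos_cond, neg_cond], dim=0)
    noise = predict_noise(latents2, cond2)
    positive, negative = noise.chunk(2, dim=0)
    return negative + cfg_scale * (positive - negative)

out = predict_noise_with_batch_cfg(torch.randn(1, 4), torch.randn(1, 8), torch.randn(1, 8), cfg_scale=3.5)
print(out.shape)  # torch.Size([1, 4])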

diffsynth_engine/pipelines/sd_image.py

Lines changed: 8 additions & 2 deletions
@@ -185,10 +185,13 @@ def __init__(
     def from_pretrained(
         cls,
         model_path_or_config: str | os.PathLike | SDModelConfig,
+        batch_cfg: bool = True,
+        vae_tiled: bool = False,
+        vae_tile_size: int = 256,
+        vae_tile_stride: int = 256,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.float16,
         offload_mode: str | None = None,
-        batch_cfg: bool = True,
     ) -> "SDImagePipeline":
         if isinstance(model_path_or_config, str):
             model_config = SDModelConfig(unet_path=model_path_or_config)
@@ -232,6 +235,9 @@ def from_pretrained(
             vae_decoder=vae_decoder,
             vae_encoder=vae_encoder,
             batch_cfg=batch_cfg,
+            vae_tiled=vae_tiled,
+            vae_tile_size=vae_tile_size,
+            vae_tile_stride=vae_tile_stride,
             device=device,
             dtype=dtype,
         )
@@ -262,7 +268,7 @@ def predict_noise_with_cfg(
         cfg_scale: float,
         batch_cfg: bool = True,
     ):
-        if cfg_scale < 1.0:
+        if cfg_scale <= 1.0:
             return self.predict_noise(latents, timestep, positive_prompt_emb)
         if not batch_cfg:
             # cfg by predict noise one by one
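Changing the early return from `cfg_scale < 1.0` to `cfg_scale <= 1.0` also skips the negative branch at exactly 1.0, where the standard CFG combination reduces to the positive prediction, so the extra forward pass was wasted work. A quick check of that identity (using the usual `negative + cfg_scale * (positive - negative)` form; illustrative only):

import torch

positive = torch.randn(1, 4, 64, 64)
negative = torch.randn(1, 4, 64, 64)
cfg_scale = 1.0

combined = negative + cfg_scale * (positive - negative)
print(torch.allclose(combined, positive))  # True: at scale 1.0 only the positive branch matters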

diffsynth_engine/pipelines/sdxl_image.py

Lines changed: 7 additions & 1 deletion
@@ -159,10 +159,13 @@ def __init__(
     def from_pretrained(
         cls,
         model_path_or_config: str | os.PathLike | SDXLModelConfig,
+        batch_cfg: bool = True,
+        vae_tiled: bool = False,
+        vae_tile_size: int = 256,
+        vae_tile_stride: int = 256,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.float16,
         offload_mode: str | None = None,
-        batch_cfg: bool = True,
     ) -> "SDXLImagePipeline":
         if isinstance(model_path_or_config, str):
             model_config = SDXLModelConfig(
@@ -220,6 +223,9 @@ def from_pretrained(
             vae_decoder=vae_decoder,
             vae_encoder=vae_encoder,
             batch_cfg=batch_cfg,
+            vae_tiled=vae_tiled,
+            vae_tile_size=vae_tile_size,
+            vae_tile_stride=vae_tile_stride,
             device=device,
             dtype=dtype,
         )
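With the reordered signature, `batch_cfg` and the VAE tiling options now sit ahead of `device`/`dtype`, so callers relying on the old positional order would break; passing them by keyword avoids that. A hypothetical call (the checkpoint path is a placeholder and the import path is inferred from the file layout, not taken from this commit):

import torch
from diffsynth_engine.pipelines.sdxl_image import SDXLImagePipeline

# Placeholder checkpoint path; all new options passed by keyword.
pipe = SDXLImagePipeline.from_pretrained(
    "/path/to/sdxl_checkpoint.safetensors",
    batch_cfg=True,
    vae_tiled=True,
    vae_tile_size=256,
    vae_tile_stride=256,
    device="cuda:0",
    dtype=torch.float16,
)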

diffsynth_engine/pipelines/wan_video.py

Lines changed: 9 additions & 2 deletions
@@ -410,9 +410,13 @@ def __call__(
     def from_pretrained(
         cls,
         model_path_or_config: str | WanModelConfig,
+        shift: float | None = None,
+        batch_cfg: bool = False,
+        vae_tiled: bool = True,
+        vae_tile_size: Tuple[int, int] = (34, 34),
+        vae_tile_stride: Tuple[int, int] = (18, 16),
         device: str = "cuda",
         dtype: torch.dtype = torch.bfloat16,
-        batch_cfg: bool = False,
         offload_mode: str | None = None,
         parallelism: int = 1,
         use_cfg_parallel: bool = False,
@@ -468,7 +472,7 @@ def from_pretrained(
            model_type = "1.3b-t2v"
 
        # shift for different model_type
-       shift = SHIFT_FACTORS[model_type]
+       shift = SHIFT_FACTORS[model_type] if shift is None else shift
 
        if parallelism > 1:
            parallel_config = cls.init_parallel_config(parallelism, use_cfg_parallel, model_config)
@@ -531,6 +535,9 @@ def from_pretrained(
            image_encoder=image_encoder,
            shift=shift,
            batch_cfg=batch_cfg,
+           vae_tiled=vae_tiled,
+           vae_tile_size=vae_tile_size,
+           vae_tile_stride=vae_tile_stride,
            device=device,
            dtype=dtype,
        )
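The new `shift` parameter lets callers override the per-model default; leaving it at `None` falls back to `SHIFT_FACTORS[model_type]` as before. A hypothetical call (placeholder model path; the class name and import path are assumed from the file layout, not confirmed by this commit):

import torch
from diffsynth_engine.pipelines.wan_video import WanVideoPipeline

# Placeholder checkpoint path; an explicit shift (e.g. 5.0) overrides the
# per-model default taken from SHIFT_FACTORS[model_type].
pipe = WanVideoPipeline.from_pretrained(
    "/path/to/wan_checkpoint.safetensors",
    shift=5.0,
    batch_cfg=False,
    vae_tiled=True,
    vae_tile_size=(34, 34),
    vae_tile_stride=(18, 16),
    device="cuda",
    dtype=torch.bfloat16,
)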
