
Commit 1280c9f

Sink the VAE tiling parameters (vae_tiled, vae_tile_size, vae_tile_stride) into the pipeline base class (#50)
1 parent 13529a7 commit 1280c9f

File tree (5 files changed: +75 −53 lines)

diffsynth_engine/pipelines/base.py
diffsynth_engine/pipelines/flux_image.py
diffsynth_engine/pipelines/sd_image.py
diffsynth_engine/pipelines/sdxl_image.py
diffsynth_engine/pipelines/wan_video.py

diffsynth_engine/pipelines/base.py

Lines changed: 12 additions & 5 deletions
@@ -25,11 +25,14 @@ def convert(self, lora_state_dict: Dict[str, torch.Tensor]) -> Dict[str, Dict[st
 class BasePipeline:
     lora_converter = LoRAStateDictConverter()
 
-    def __init__(self, device="cuda:0", dtype=torch.float16):
+    def __init__(self, vae_tiled, vae_tile_size, vae_tile_stride, device="cuda:0", dtype=torch.float16):
         super().__init__()
         self.device = device
         self.dtype = dtype
         self.offload_mode = None
+        self.vae_tiled = vae_tiled
+        self.vae_tile_size = vae_tile_size
+        self.vae_tile_stride = vae_tile_stride
         self.model_names = []
 
     @classmethod

@@ -140,13 +143,17 @@ def generate_noise(shape, seed=None, device="cpu", dtype=torch.float16):
         noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
         return noise
 
-    def encode_image(self, image: torch.Tensor, tiled=False, tile_size=64, tile_stride=32) -> torch.Tensor:
-        latents = self.vae_encoder(image, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+    def encode_image(self, image: torch.Tensor) -> torch.Tensor:
+        latents = self.vae_encoder(
+            image, tiled=self.vae_tiled, tile_size=self.vae_tile_size, tile_stride=self.vae_tile_stride
+        )
         return latents
 
-    def decode_image(self, latent: torch.Tensor, tiled=False, tile_size=64, tile_stride=32) -> torch.Tensor:
+    def decode_image(self, latent: torch.Tensor) -> torch.Tensor:
         vae_dtype = self.vae_decoder.conv_in.weight.dtype
-        image = self.vae_decoder(latent.to(vae_dtype), tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+        image = self.vae_decoder(
+            latent.to(vae_dtype), tiled=self.vae_tiled, tile_size=self.vae_tile_size, tile_stride=self.vae_tile_stride
+        )
         return image
 
     def prepare_latents(
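Net effect of the base.py change: tiling is configured once at construction time, and every encode/decode call site loses three arguments. A minimal runnable sketch of the pattern with toy stand-ins rather than the real diffsynth_engine classes (ToyVAEEncoder and ToyBasePipeline are hypothetical):

import torch

class ToyVAEEncoder:
    def __call__(self, image, tiled=False, tile_size=256, tile_stride=256):
        # A real encoder would process overlapping windows when tiled=True;
        # this stand-in only mimics the 8x spatial downsampling.
        return torch.zeros(1, 4, image.shape[-2] // 8, image.shape[-1] // 8)

class ToyBasePipeline:
    def __init__(self, vae_tiled, vae_tile_size, vae_tile_stride, device="cpu", dtype=torch.float32):
        # Tiling options live on the instance, mirroring BasePipeline above.
        self.vae_tiled = vae_tiled
        self.vae_tile_size = vae_tile_size
        self.vae_tile_stride = vae_tile_stride
        self.device = device
        self.dtype = dtype
        self.vae_encoder = ToyVAEEncoder()

    def encode_image(self, image: torch.Tensor) -> torch.Tensor:
        # Callers no longer thread tiled/tile_size/tile_stride through each call.
        return self.vae_encoder(
            image, tiled=self.vae_tiled, tile_size=self.vae_tile_size, tile_stride=self.vae_tile_stride
        )

pipe = ToyBasePipeline(vae_tiled=True, vae_tile_size=256, vae_tile_stride=256)
print(pipe.encode_image(torch.zeros(1, 3, 512, 512)).shape)  # torch.Size([1, 4, 64, 64])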

diffsynth_engine/pipelines/flux_image.py

Lines changed: 15 additions & 12 deletions
@@ -225,10 +225,19 @@ def __init__(
         vae_encoder: FluxVAEEncoder,
         use_cfg: bool = False,
         batch_cfg: bool = False,
+        vae_tiled: bool = False,
+        vae_tile_size: int = 256,
+        vae_tile_stride: int = 256,
         device: str = "cuda:0",
         dtype: torch.dtype = torch.bfloat16,
     ):
-        super().__init__(device=device, dtype=dtype)
+        super().__init__(
+            vae_tiled=vae_tiled,
+            vae_tile_size=vae_tile_size,
+            vae_tile_stride=vae_tile_stride,
+            device=device,
+            dtype=dtype,
+        )
         self.noise_scheduler = RecifitedFlowScheduler(shift=3.0, use_dynamic_shifting=True)
         self.sampler = FlowMatchEulerSampler()
         # models

@@ -474,9 +483,6 @@ def prepare_latents(
         denoising_strength: float,
         num_inference_steps: int,
         mu: float,
-        tiled: bool = False,
-        tile_size: int = 128,
-        tile_stride: int = 64,
     ):
         # Prepare scheduler
         if input_image is not None:

@@ -491,7 +497,7 @@ def prepare_latents(
             self.load_models_to_device(["vae_encoder"])
             noise = latents
             image = self.preprocess_image(input_image).to(device=self.device, dtype=self.dtype)
-            latents = self.encode_image(image, tiled, tile_size, tile_stride)
+            latents = self.encode_image(image)
             init_latents = latents.clone()
             latents = self.sampler.add_noise(latents, noise, sigma_start)
         else:

@@ -506,15 +512,15 @@ def prepare_masked_latent(self, image: Image.Image, mask: Image.Image | None, he
         if mask is None:
             image = image.resize((width, height))
             image = self.preprocess_image(image).to(device=self.device, dtype=self.dtype)
-            latent = self.encode_image(image, tiled=False)
+            latent = self.encode_image(image)
         else:
             image = image.resize((width, height))
             mask = mask.resize((width, height))
             image = self.preprocess_image(image).to(device=self.device, dtype=self.dtype)
             mask = self.preprocess_mask(mask).to(device=self.device, dtype=self.dtype)
             masked_image = image.clone()
             masked_image[(mask > 0.5).repeat(1, 3, 1, 1)] = -1
-            latent = self.encode_image(masked_image, tiled=False)
+            latent = self.encode_image(masked_image)
             mask = torch.nn.functional.interpolate(mask, size=(latent.shape[2], latent.shape[3]))
             mask = 1 - mask
             latent = torch.cat([latent, mask], dim=1)

@@ -585,9 +591,6 @@ def __call__(
         height: int = 1024,
         width: int = 1024,
         num_inference_steps: int = 30,
-        tiled: bool = False,
-        tile_size: int = 128,
-        tile_stride: int = 64,
         seed: int | None = None,
         controlnet_params: List[ControlNetParams] | ControlNetParams = [],
         progress_callback: Optional[Callable] = None,  # def progress_callback(current, total, status)

@@ -605,7 +608,7 @@ def __call__(
         image_seq_len = math.ceil(height // 16) * math.ceil(width // 16)
         mu = calculate_shift(image_seq_len)
         init_latents, latents, sigmas, timesteps = self.prepare_latents(
-            noise, input_image, denoising_strength, num_inference_steps, mu, tiled, tile_size, tile_stride
+            noise, input_image, denoising_strength, num_inference_steps, mu
         )
         # Initialize sampler
         self.sampler.initialize(init_latents=init_latents, timesteps=timesteps, sigmas=sigmas)

@@ -649,7 +652,7 @@ def __call__(
                 progress_callback(i, len(timesteps), "DENOISING")
         # Decode image
         self.load_models_to_device(["vae_decoder"])
-        vae_output = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+        vae_output = self.decode_image(latents)
         image = self.vae_output_to_image(vae_output)
         # Offload all models
         self.load_models_to_device([])
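For readers unfamiliar with tiled VAE inference: tile_size is the window processed at once and tile_stride the step between windows, so adjacent windows overlap by tile_size - tile_stride and the overlap can be blended to hide seams. A sketch of that windowing under those assumed semantics (iter_tiles is illustrative, not the FluxVAEEncoder implementation):

import torch

def iter_tiles(x: torch.Tensor, tile_size: int, tile_stride: int):
    # Yield (top, left, window) views; stride < size makes windows overlap.
    _, _, h, w = x.shape
    for top in range(0, h, tile_stride):
        for left in range(0, w, tile_stride):
            yield top, left, x[:, :, top:top + tile_size, left:left + tile_size]

x = torch.zeros(1, 3, 512, 512)
print(sum(1 for _ in iter_tiles(x, tile_size=256, tile_stride=128)))  # 16 overlapping windows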

diffsynth_engine/pipelines/sd_image.py

Lines changed: 12 additions & 6 deletions
@@ -155,10 +155,19 @@ def __init__(
         vae_decoder: SDVAEDecoder,
         vae_encoder: SDVAEEncoder,
         batch_cfg: bool = True,
+        vae_tiled: bool = False,
+        vae_tile_size: int = 256,
+        vae_tile_stride: int = 256,
         device: str = "cuda",
         dtype: torch.dtype = torch.float16,
     ):
-        super().__init__(device=device, dtype=dtype)
+        super().__init__(
+            vae_tiled=vae_tiled,
+            vae_tile_size=vae_tile_size,
+            vae_tile_stride=vae_tile_stride,
+            device=device,
+            dtype=dtype,
+        )
         self.noise_scheduler = ScaledLinearScheduler()
         self.sampler = EulerSampler()
         # models

@@ -310,9 +319,6 @@ def __call__(
         height: int = 1024,
         width: int = 1024,
         num_inference_steps: int = 20,
-        tiled: bool = False,
-        tile_size: int = 64,
-        tile_stride: int = 32,
         seed: int | None = None,
         progress_callback: Optional[Callable] = None,  # def progress_callback(current, total, status)
     ):

@@ -322,7 +328,7 @@ def __call__(
         noise = self.generate_noise((1, 4, height // 8, width // 8), seed=seed, device=self.device, dtype=self.dtype)
 
         init_latents, latents, sigmas, timesteps = self.prepare_latents(
-            noise, input_image, denoising_strength, num_inference_steps, tiled, tile_size, tile_stride
+            noise, input_image, denoising_strength, num_inference_steps
         )
         mask, overlay_image = None, None
         if mask_image is not None:

@@ -359,7 +365,7 @@ def __call__(
             latents = latents * mask + init_latents * (1 - mask)
         # Decode image
         self.load_models_to_device(["vae_decoder"])
-        vae_output = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+        vae_output = self.decode_image(latents)
         image = self.vae_output_to_image(vae_output)
         # Paste Overlay Image
         if mask_image is not None:
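The noise shape in the hunk above encodes the SD VAE geometry: 8x spatial downsampling into 4 latent channels, so decode_image upsamples a (1, 4, height // 8, width // 8) latent back to pixels and the tiling parameters operate at one-eighth the pixel resolution. A quick check of the arithmetic:

import torch

height, width = 1024, 1024
noise = torch.randn(1, 4, height // 8, width // 8)  # (batch, latent channels, h/8, w/8)
print(noise.shape)  # torch.Size([1, 4, 128, 128])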

diffsynth_engine/pipelines/sdxl_image.py

Lines changed: 12 additions & 6 deletions
@@ -124,10 +124,19 @@ def __init__(
         vae_decoder: SDXLVAEDecoder,
         vae_encoder: SDXLVAEEncoder,
         batch_cfg: bool = True,
+        vae_tiled: bool = False,
+        vae_tile_size: int = 256,
+        vae_tile_stride: int = 256,
         device: str = "cuda",
         dtype: torch.dtype = torch.float16,
     ):
-        super().__init__(device=device, dtype=dtype)
+        super().__init__(
+            vae_tiled=vae_tiled,
+            vae_tile_size=vae_tile_size,
+            vae_tile_stride=vae_tile_stride,
+            device=device,
+            dtype=dtype,
+        )
         self.noise_scheduler = ScaledLinearScheduler()
         self.sampler = EulerSampler()
         # models

@@ -342,9 +351,6 @@ def __call__(
         height: int = 1024,
         width: int = 1024,
         num_inference_steps: int = 20,
-        tiled: bool = False,
-        tile_size: int = 64,
-        tile_stride: int = 32,
         seed: int | None = None,
         progress_callback: Optional[Callable] = None,  # def progress_callback(current, total, status)
     ):

@@ -354,7 +360,7 @@ def __call__(
         noise = self.generate_noise((1, 4, height // 8, width // 8), seed=seed, device=self.device, dtype=self.dtype)
 
         init_latents, latents, sigmas, timesteps = self.prepare_latents(
-            noise, input_image, denoising_strength, num_inference_steps, tiled, tile_size, tile_stride
+            noise, input_image, denoising_strength, num_inference_steps
         )
         mask, overlay_image = None, None
         if mask_image is not None:

@@ -402,7 +408,7 @@ def __call__(
             latents = latents * mask + init_latents * (1 - mask)
         # Decode image
         self.load_models_to_device(["vae_decoder"])
-        vae_output = self.decode_image(latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+        vae_output = self.decode_image(latents)
         image = self.vae_output_to_image(vae_output)
 
         if mask_image is not None:
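The blend in the last hunk, latents = latents * mask + init_latents * (1 - mask), is the inpainting step: where the mask is 1 the freshly denoised latents are kept, elsewhere the latents of the encoded input image are restored so untouched regions stay stable. A self-contained illustration with toy tensors:

import torch

latents = torch.ones(1, 4, 8, 8)        # stand-in for denoised latents
init_latents = torch.zeros(1, 4, 8, 8)  # stand-in for the encoded input image
mask = torch.zeros(1, 1, 8, 8)
mask[..., :4, :] = 1.0                  # top half is the region to repaint
blended = latents * mask + init_latents * (1 - mask)
print(blended[0, 0, 0, 0].item(), blended[0, 0, 7, 0].item())  # 1.0 0.0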

diffsynth_engine/pipelines/wan_video.py

Lines changed: 24 additions & 24 deletions
@@ -128,10 +128,19 @@ def __init__(
         vae: WanVideoVAE,
         image_encoder: WanImageEncoder,
         batch_cfg: bool = False,
+        vae_tiled: bool = True,
+        vae_tile_size: Tuple[int, int] = (34, 34),
+        vae_tile_stride: Tuple[int, int] = (18, 16),
         device="cuda",
         dtype=torch.bfloat16,
     ):
-        super().__init__(device=device, dtype=dtype)
+        super().__init__(
+            vae_tiled=vae_tiled,
+            vae_tile_size=vae_tile_size,
+            vae_tile_stride=vae_tile_stride,
+            device=device,
+            dtype=dtype,
+        )
         self.noise_scheduler = RecifitedFlowScheduler(shift=5.0, sigma_min=0.001, sigma_max=0.999)
         self.sampler = FlowMatchEulerSampler()
         self.tokenizer = tokenizer

@@ -202,22 +211,26 @@ def tensor2video(self, frames):
         frames = [Image.fromarray(frame) for frame in frames]
         return frames
 
-    def encode_video(self, videos: torch.Tensor, tiled=True, tile_size=(34, 34), tile_stride=(18, 16)):
+    def encode_video(self, videos: torch.Tensor):
         videos = videos.to(dtype=self.config.vae_dtype, device=self.device)
-        latents = self.vae.encode(videos, device=self.device, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride)
+        latents = self.vae.encode(
+            videos,
+            device=self.device,
+            tiled=self.vae_tiled,
+            tile_size=self.vae_tile_size,
+            tile_stride=self.vae_tile_stride,
+        )
         latents = latents.to(dtype=self.config.dit_dtype, device=self.device)
         return latents
 
-    def decode_video(
-        self, latents, tiled=True, tile_size=(34, 34), tile_stride=(18, 16), progress_callback=None
-    ) -> List[torch.Tensor]:
+    def decode_video(self, latents, progress_callback=None) -> List[torch.Tensor]:
         latents = latents.to(dtype=self.config.vae_dtype, device=self.device)
         videos = self.vae.decode(
             latents,
             device=self.device,
-            tiled=tiled,
-            tile_size=tile_size,
-            tile_stride=tile_stride,
+            tiled=self.vae_tiled,
+            tile_size=self.vae_tile_size,
+            tile_stride=self.vae_tile_stride,
             progress_callback=progress_callback,
         )
         videos = [video.to(dtype=self.config.dit_dtype, device=self.device) for video in videos]

@@ -297,9 +310,6 @@ def prepare_latents(
         input_video,
         denoising_strength,
         num_inference_steps,
-        tiled=True,
-        tile_size=(34, 34),
-        tile_stride=(18, 16),
     ):
         if input_video is not None:
             total_steps = num_inference_steps

@@ -311,9 +321,7 @@ def prepare_latents(
             noise = latents
             input_video = self.preprocess_images(input_video)
             input_video = torch.stack(input_video, dim=2)
-            latents = self.encode_video(input_video, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride).to(
-                dtype=latents.dtype, device=latents.device
-            )
+            latents = self.encode_video(input_video).to(dtype=latents.dtype, device=latents.device)
             init_latents = latents.clone()
             latents = self.sampler.add_noise(latents, noise, sigma_start)
         else:

@@ -336,9 +344,6 @@ def __call__(
         num_frames=81,
         cfg_scale=5.0,
         num_inference_steps=50,
-        tiled=True,
-        tile_size=(34, 34),
-        tile_stride=(18, 16),
         progress_callback: Optional[Callable] = None,  # def progress_callback(current, total, status)
     ):
         assert height % 16 == 0 and width % 16 == 0, "height and width must be divisible by 16"

@@ -353,9 +358,6 @@ def __call__(
             input_video,
             denoising_strength,
             num_inference_steps,
-            tiled=tiled,
-            tile_size=tile_size,
-            tile_stride=tile_stride,
         )
         self.sampler.initialize(init_latents=init_latents, timesteps=timesteps, sigmas=sigmas)
         # Encode prompts

@@ -392,9 +394,7 @@ def __call__(
 
         # Decode
         self.load_models_to_device(["vae"])
-        frames = self.decode_video(
-            latents, tiled=tiled, tile_size=tile_size, tile_stride=tile_stride, progress_callback=progress_callback
-        )
+        frames = self.decode_video(latents, progress_callback=progress_callback)
         frames = self.tensor2video(frames[0])
         return frames
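Note that WanVideoPipeline, unlike the image pipelines, defaults to vae_tiled=True and takes per-axis tuples, with the stride smaller than the size so neighboring tiles overlap. Assuming the same window/stride semantics as in the sketch above, the per-axis overlap is the elementwise difference:

tile_size = (34, 34)    # latent-space window (height, width), the defaults in the diff
tile_stride = (18, 16)  # step between windows along each axis
overlap = tuple(size - stride for size, stride in zip(tile_size, tile_stride))
print(overlap)  # (16, 18)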
