Skip to content

Commit 7a47d10

Browse files
committed
flux pipeline: readability enhancement.
1 parent 5704376 commit 7a47d10

File tree

6 files changed

+69
-69
lines changed

6 files changed

+69
-69
lines changed

src/diffusers/pipelines/flux/pipeline_flux.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -195,13 +195,13 @@ def __init__(
195195
scheduler=scheduler,
196196
)
197197
self.vae_scale_factor = (
198-
2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
198+
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
199199
)
200200
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
201201
self.tokenizer_max_length = (
202202
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
203203
)
204-
self.default_sample_size = 64
204+
self.default_sample_size = 128
205205

206206
def _get_t5_prompt_embeds(
207207
self,
@@ -425,9 +425,9 @@ def check_inputs(
425425

426426
@staticmethod
427427
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
428-
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
429-
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
430-
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
428+
latent_image_ids = torch.zeros(height, width, 3)
429+
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
430+
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
431431

432432
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
433433

@@ -452,10 +452,10 @@ def _unpack_latents(latents, height, width, vae_scale_factor):
452452
height = height // vae_scale_factor
453453
width = width // vae_scale_factor
454454

455-
latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
455+
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
456456
latents = latents.permute(0, 3, 1, 4, 2, 5)
457457

458-
latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
458+
latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
459459

460460
return latents
461461

@@ -499,8 +499,8 @@ def prepare_latents(
499499
generator,
500500
latents=None,
501501
):
502-
height = 2 * (int(height) // self.vae_scale_factor)
503-
width = 2 * (int(width) // self.vae_scale_factor)
502+
height = int(height) // self.vae_scale_factor
503+
width = int(width) // self.vae_scale_factor
504504

505505
shape = (batch_size, num_channels_latents, height, width)
506506

@@ -517,7 +517,7 @@ def prepare_latents(
517517
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
518518
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
519519

520-
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
520+
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
521521

522522
return latents, latent_image_ids
523523

src/diffusers/pipelines/flux/pipeline_flux_controlnet.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -216,13 +216,13 @@ def __init__(
216216
controlnet=controlnet,
217217
)
218218
self.vae_scale_factor = (
219-
2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
219+
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
220220
)
221221
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
222222
self.tokenizer_max_length = (
223223
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
224224
)
225-
self.default_sample_size = 64
225+
self.default_sample_size = 128
226226

227227
def _get_t5_prompt_embeds(
228228
self,
@@ -450,9 +450,9 @@ def check_inputs(
450450
@staticmethod
451451
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
452452
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
453-
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
454-
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
455-
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
453+
latent_image_ids = torch.zeros(height, width, 3)
454+
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
455+
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
456456

457457
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
458458

@@ -479,10 +479,10 @@ def _unpack_latents(latents, height, width, vae_scale_factor):
479479
height = height // vae_scale_factor
480480
width = width // vae_scale_factor
481481

482-
latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
482+
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
483483
latents = latents.permute(0, 3, 1, 4, 2, 5)
484484

485-
latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
485+
latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
486486

487487
return latents
488488

@@ -498,13 +498,13 @@ def prepare_latents(
498498
generator,
499499
latents=None,
500500
):
501-
height = 2 * (int(height) // self.vae_scale_factor)
502-
width = 2 * (int(width) // self.vae_scale_factor)
501+
height = int(height) // self.vae_scale_factor
502+
width = int(width) // self.vae_scale_factor
503503

504504
shape = (batch_size, num_channels_latents, height, width)
505505

506506
if latents is not None:
507-
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
507+
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
508508
return latents.to(device=device, dtype=dtype), latent_image_ids
509509

510510
if isinstance(generator, list) and len(generator) != batch_size:
@@ -516,7 +516,7 @@ def prepare_latents(
516516
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
517517
latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
518518

519-
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
519+
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
520520

521521
return latents, latent_image_ids
522522

src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -228,13 +228,13 @@ def __init__(
228228
controlnet=controlnet,
229229
)
230230
self.vae_scale_factor = (
231-
2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
231+
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
232232
)
233233
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
234234
self.tokenizer_max_length = (
235235
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
236236
)
237-
self.default_sample_size = 64
237+
self.default_sample_size = 128
238238

239239
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
240240
def _get_t5_prompt_embeds(
@@ -493,9 +493,9 @@ def check_inputs(
493493
@staticmethod
494494
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
495495
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
496-
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
497-
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
498-
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
496+
latent_image_ids = torch.zeros(height, width, 3)
497+
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
498+
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
499499

500500
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
501501

@@ -522,10 +522,10 @@ def _unpack_latents(latents, height, width, vae_scale_factor):
522522
height = height // vae_scale_factor
523523
width = width // vae_scale_factor
524524

525-
latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
525+
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
526526
latents = latents.permute(0, 3, 1, 4, 2, 5)
527527

528-
latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
528+
latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
529529

530530
return latents
531531

@@ -549,11 +549,11 @@ def prepare_latents(
549549
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
550550
)
551551

552-
height = 2 * (int(height) // self.vae_scale_factor)
553-
width = 2 * (int(width) // self.vae_scale_factor)
552+
height = int(height) // self.vae_scale_factor
553+
width = int(width) // self.vae_scale_factor
554554

555555
shape = (batch_size, num_channels_latents, height, width)
556-
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
556+
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
557557

558558
if latents is not None:
559559
return latents.to(device=device, dtype=dtype), latent_image_ids
@@ -852,7 +852,7 @@ def __call__(
852852
control_mode = control_mode.reshape([-1, 1])
853853

854854
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
855-
image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
855+
image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
856856
mu = calculate_shift(
857857
image_seq_len,
858858
self.scheduler.config.base_image_seq_len,

src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -231,7 +231,7 @@ def __init__(
231231
)
232232

233233
self.vae_scale_factor = (
234-
2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
234+
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
235235
)
236236
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
237237
self.mask_processor = VaeImageProcessor(
@@ -244,7 +244,7 @@ def __init__(
244244
self.tokenizer_max_length = (
245245
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
246246
)
247-
self.default_sample_size = 64
247+
self.default_sample_size = 128
248248

249249
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
250250
def _get_t5_prompt_embeds(
@@ -520,9 +520,9 @@ def check_inputs(
520520
@staticmethod
521521
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
522522
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
523-
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
524-
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
525-
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
523+
latent_image_ids = torch.zeros(height, width, 3)
524+
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
525+
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
526526

527527
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
528528

@@ -549,10 +549,10 @@ def _unpack_latents(latents, height, width, vae_scale_factor):
549549
height = height // vae_scale_factor
550550
width = width // vae_scale_factor
551551

552-
latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
552+
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
553553
latents = latents.permute(0, 3, 1, 4, 2, 5)
554554

555-
latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
555+
latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
556556

557557
return latents
558558

@@ -576,11 +576,11 @@ def prepare_latents(
576576
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
577577
)
578578

579-
height = 2 * (int(height) // self.vae_scale_factor)
580-
width = 2 * (int(width) // self.vae_scale_factor)
579+
height = int(height) // self.vae_scale_factor
580+
width = int(width) // self.vae_scale_factor
581581

582582
shape = (batch_size, num_channels_latents, height, width)
583-
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
583+
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
584584

585585
image = image.to(device=device, dtype=dtype)
586586
image_latents = self._encode_vae_image(image=image, generator=generator)
@@ -622,8 +622,8 @@ def prepare_mask_latents(
622622
device,
623623
generator,
624624
):
625-
height = 2 * (int(height) // self.vae_scale_factor)
626-
width = 2 * (int(width) // self.vae_scale_factor)
625+
height = int(height) // self.vae_scale_factor
626+
width = int(width) // self.vae_scale_factor
627627
# resize the mask to latents shape as we concatenate the mask to the latents
628628
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
629629
# and half precision
@@ -996,7 +996,7 @@ def __call__(
996996
# 6. Prepare timesteps
997997

998998
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
999-
image_seq_len = (int(global_height) // self.vae_scale_factor) * (int(global_width) // self.vae_scale_factor)
999+
image_seq_len = (int(global_height) // self.vae_scale_factor // 2) * (int(global_width) // self.vae_scale_factor // 2)
10001000
mu = calculate_shift(
10011001
image_seq_len,
10021002
self.scheduler.config.base_image_seq_len,

src/diffusers/pipelines/flux/pipeline_flux_img2img.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -212,13 +212,13 @@ def __init__(
212212
scheduler=scheduler,
213213
)
214214
self.vae_scale_factor = (
215-
2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
215+
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
216216
)
217217
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
218218
self.tokenizer_max_length = (
219219
self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
220220
)
221-
self.default_sample_size = 64
221+
self.default_sample_size = 128
222222

223223
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
224224
def _get_t5_prompt_embeds(
@@ -477,9 +477,9 @@ def check_inputs(
477477
@staticmethod
478478
# Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
479479
def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
480-
latent_image_ids = torch.zeros(height // 2, width // 2, 3)
481-
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
482-
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
480+
latent_image_ids = torch.zeros(height, width, 3)
481+
latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
482+
latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
483483

484484
latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
485485

@@ -506,10 +506,10 @@ def _unpack_latents(latents, height, width, vae_scale_factor):
506506
height = height // vae_scale_factor
507507
width = width // vae_scale_factor
508508

509-
latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
509+
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
510510
latents = latents.permute(0, 3, 1, 4, 2, 5)
511511

512-
latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
512+
latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
513513

514514
return latents
515515

@@ -532,11 +532,11 @@ def prepare_latents(
532532
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
533533
)
534534

535-
height = 2 * (int(height) // self.vae_scale_factor)
536-
width = 2 * (int(width) // self.vae_scale_factor)
535+
height = int(height) // self.vae_scale_factor
536+
width = int(width) // self.vae_scale_factor
537537

538538
shape = (batch_size, num_channels_latents, height, width)
539-
latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
539+
latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
540540

541541
if latents is not None:
542542
return latents.to(device=device, dtype=dtype), latent_image_ids
@@ -736,7 +736,7 @@ def __call__(
736736

737737
# 4.Prepare timesteps
738738
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
739-
image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
739+
image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
740740
mu = calculate_shift(
741741
image_seq_len,
742742
self.scheduler.config.base_image_seq_len,

0 commit comments

Comments
 (0)