Skip to content

Commit 1c194da

Browse files
authored
Merge pull request #360 from TobyRoseman/unet-batch-1-or-2
Allow not using classifier free guidance
2 parents cab24e9 + e497107 commit 1c194da

File tree

3 files changed

+19
-3
lines changed

3 files changed

+19
-3
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -490,6 +490,8 @@ This generally takes 15-20 minutes on an M1 MacBook Pro. Upon successful executi
490490

491491
- `--unet-support-controlnet`: enables a converted UNet model to receive additional inputs from ControlNet. This is required for generating images with ControlNet. The resulting model is saved under a different name, `*_control-unet.mlpackage`, distinct from the normal UNet. Note that this UNet model cannot work without ControlNet; please use the normal UNet for plain txt2img.
492492

493+
- `--unet-batch-one`: use a batch size of one for the UNet. This is needed if you do not want to use classifier-free guidance, i.e. when using a `guidance-scale` of less than one.
494+
493495
- `--convert-vae-encoder`: not required for text-to-image applications. Required for image-to-image applications in order to map the input image to the latent space.
494496

495497
</details>

python_coreml_stable_diffusion/pipeline.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -506,11 +506,16 @@ def __call__(
506506
if isinstance(latent_model_input, torch.Tensor):
507507
latent_model_input = latent_model_input.numpy()
508508

509+
if do_classifier_free_guidance:
510+
timestep = np.array([t, t], np.float16)
511+
else:
512+
timestep = np.array([t,], np.float16)
513+
509514
# controlnet
510515
if controlnet_cond:
511516
control_net_additional_residuals = self.run_controlnet(
512517
sample=latent_model_input,
513-
timestep=np.array([t, t]),
518+
timestep=timestep,
514519
encoder_hidden_states=text_embeddings,
515520
controlnet_cond=controlnet_cond,
516521
)
@@ -522,7 +527,7 @@ def __call__(
522527

523528
noise_pred = self.unet(
524529
sample=latent_model_input.astype(np.float16),
525-
timestep=np.array([t, t], np.float16),
530+
timestep=timestep,
526531
encoder_hidden_states=text_embeddings.astype(np.float16),
527532
**unet_additional_kwargs,
528533
)["noise_pred"]

python_coreml_stable_diffusion/torch2coreml.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -757,7 +757,7 @@ def forward(self, x):
757757
gc.collect()
758758

759759

760-
def convert_unet(pipe, args, model_name = None):
760+
def convert_unet(pipe, args, model_name=None):
761761
""" Converts the UNet component of Stable Diffusion
762762
"""
763763
if args.unet_support_controlnet:
@@ -783,6 +783,8 @@ def convert_unet(pipe, args, model_name = None):
783783
elif not os.path.exists(out_path):
784784
# Prepare sample input shapes and values
785785
batch_size = 2 # for classifier-free guidance
786+
if args.unet_batch_one:
787+
batch_size = 1 # for not using classifier-free guidance
786788
sample_shape = (
787789
batch_size, # B
788790
pipe.unet.config.in_channels, # C
@@ -1674,6 +1676,13 @@ def parser_spec():
16741676
"If specified, enable unet to receive additional inputs from controlnet. "
16751677
"Each input added to corresponding resnet output."
16761678
)
1679+
parser.add_argument(
1680+
"--unet-batch-one",
1681+
action="store_true",
1682+
help=
1683+
"If specified, a batch size of one will be used for the unet, this is needed if you do not want to do "
1684+
"classifier free guidance. Default unet batch size is two, which is needed for classifier free guidance."
1685+
)
16771686
parser.add_argument("--include-t5", action="store_true")
16781687

16791688
# Swift CLI Resource Bundling

0 commit comments

Comments
 (0)