
Commit 7505f91: Add option for sequential unet predictions
Parent: 6365319

File tree: 2 files changed (+39, -10 lines)


README.md (1 addition, 0 deletions)

@@ -514,6 +514,7 @@ Please refer to the help menu for all available arguments: `python -m python_cor
 - `--scheduler`: If you would like to experiment with different schedulers, you may specify it here. For available options, please see the help menu. You may also specify a custom number of inference steps by `--num-inference-steps` which defaults to 50.
 - `--controlnet`: ControlNet models specified with this option are used in image generation. Use this option in the format `--controlnet lllyasviel/sd-controlnet-mlsd lllyasviel/sd-controlnet-depth` and make sure to use `--controlnet-inputs` in conjunction.
 - `--controlnet-inputs`: Image inputs corresponding to each ControlNet model. Please provide image paths in same order as models in `--controlnet`, for example: `--controlnet-inputs image_mlsd image_depth`.
+- `--unet-batch-one`: Do not batch unet predictions for the prompt and negative prompt. This requires that the unet has been converted with a batch size of one; see the `--unet-batch-one` option in the conversion script.
 
 </details>
 
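Usage note (illustrative, not part of this diff): after converting the unet with a batch size of one, the new flag is passed alongside the usual generation arguments, for example `python -m python_coreml_stable_diffusion.pipeline --prompt "..." --negative-prompt "..." --unet-batch-one ...`. The full module path is assumed from the repository layout above; the remaining model and output arguments are elided.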
python_coreml_stable_diffusion/pipeline.py (38 additions, 10 deletions)

@@ -416,9 +416,10 @@ def __call__(
         callback=None,
         callback_steps=1,
         controlnet_cond=None,
-        original_size: Optional[Tuple[int, int]] = None,
-        crops_coords_top_left: Tuple[int, int] = (0, 0),
-        target_size: Optional[Tuple[int, int]] = None,
+        original_size: Optional[Tuple[int, int]]=None,
+        crops_coords_top_left: Tuple[int, int]=(0, 0),
+        target_size: Optional[Tuple[int, int]]=None,
+        unet_batch_one=False,
         **kwargs,
     ):
         # 1. Check inputs. Raise error if not correct
@@ -525,16 +526,38 @@ def __call__(
             # predict the noise residual
             unet_additional_kwargs.update(control_net_additional_residuals)
 
-            noise_pred = self.unet(
-                sample=latent_model_input.astype(np.float16),
-                timestep=timestep,
-                encoder_hidden_states=text_embeddings.astype(np.float16),
-                **unet_additional_kwargs,
-            )["noise_pred"]
+            # get prediction from unet
+            if not (unet_batch_one and do_classifier_free_guidance):
+                noise_pred = self.unet(
+                    sample=latent_model_input.astype(np.float16),
+                    timestep=timestep,
+                    encoder_hidden_states=text_embeddings.astype(np.float16),
+                    **unet_additional_kwargs,
+                )["noise_pred"]
+
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
+            else:
+                # query unet sequentially
+                latent_model_input = latent_model_input.astype(np.float16)
+                text_embeddings = text_embeddings.astype(np.float16)
+                timestep = np.array([t,], np.float16)
+
+                noise_pred_uncond = self.unet(
+                    sample=np.expand_dims(latent_model_input[0], axis=0),
+                    timestep=timestep,
+                    encoder_hidden_states=np.expand_dims(text_embeddings[0], axis=0),
+                    **unet_additional_kwargs,
+                )["noise_pred"]
+                noise_pred_text = self.unet(
+                    sample=np.expand_dims(latent_model_input[1], axis=0),
+                    timestep=timestep,
+                    encoder_hidden_states=np.expand_dims(text_embeddings[1], axis=0),
+                    **unet_additional_kwargs,
+                )["noise_pred"]
 
             # perform guidance
             if do_classifier_free_guidance:
-                noise_pred_uncond, noise_pred_text = np.split(noise_pred, 2)
                 noise_pred = noise_pred_uncond + guidance_scale * (
                     noise_pred_text - noise_pred_uncond)
 
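For orientation, here is a minimal NumPy sketch of the idea behind this hunk (the function, its argument layout, and the `unet` callable are illustrative stand-ins, not the pipeline's actual API): with a batch-one unet, the unconditional and text-conditioned noise predictions come from two separate forward passes instead of one batched call, and classifier-free guidance combines them the same way afterwards.

    import numpy as np

    def predict_noise(unet, latents, uncond_emb, text_emb, timestep,
                      guidance_scale, unet_batch_one=False):
        # `unet` stands in for a Core ML model wrapper returning a dict with a
        # "noise_pred" array, as in the pipeline above; `latents` has batch size 1.
        if not unet_batch_one:
            # Batched path: run both conditions through the unet in one call.
            sample = np.concatenate([latents, latents]).astype(np.float16)
            embeddings = np.concatenate([uncond_emb, text_emb]).astype(np.float16)
            noise = unet(sample=sample, timestep=timestep,
                         encoder_hidden_states=embeddings)["noise_pred"]
            noise_uncond, noise_text = np.split(noise, 2)
        else:
            # Sequential path: the unet only accepts batch size 1, so each
            # condition gets its own forward pass.
            latents = latents.astype(np.float16)
            noise_uncond = unet(sample=latents, timestep=timestep,
                                encoder_hidden_states=uncond_emb.astype(np.float16))["noise_pred"]
            noise_text = unet(sample=latents, timestep=timestep,
                              encoder_hidden_states=text_emb.astype(np.float16))["noise_pred"]
        # Guidance is applied identically in both paths.
        return noise_uncond + guidance_scale * (noise_text - noise_uncond)

The trade-off is two unet evaluations per denoising step in exchange for a model that only ever sees batch size one.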
@@ -751,6 +774,7 @@ def main(args):
         guidance_scale=args.guidance_scale,
         controlnet_cond=controlnet_cond,
         negative_prompt=args.negative_prompt,
+        unet_batch_one=args.unet_batch_one,
     )
 
     out_path = get_image_path(args)
@@ -821,6 +845,10 @@ def main(args):
         "--negative-prompt",
         default=None,
         help="The negative text prompt to be used for text-to-image generation.")
+    parser.add_argument(
+        "--unet-batch-one",
+        action="store_true",
+        help="Do not batch unet predictions for the prompt and negative prompt.")
     parser.add_argument('--model-sources',
         default=None,
         choices=['packages', 'compiled'],