Skip to content

Commit d3cd6d9

Browse files
committed
feat: add web ui for core ml stable diffusion
1 parent 583cc04 commit d3cd6d9

File tree

3 files changed

+100
-1
lines changed

3 files changed

+100
-1
lines changed

python_coreml_stable_diffusion/web.py

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
2+
import python_coreml_stable_diffusion.pipeline as pipeline
3+
4+
import gradio as gr
5+
from diffusers import StableDiffusionPipeline
6+
7+
def init(args):
8+
pipeline.logger.info("Initializing PyTorch pipe for reference configuration")
9+
pytorch_pipe = StableDiffusionPipeline.from_pretrained(args.model_version,
10+
use_auth_token=True)
11+
12+
user_specified_scheduler = None
13+
if args.scheduler is not None:
14+
user_specified_scheduler = pipeline.SCHEDULER_MAP[
15+
args.scheduler].from_config(pytorch_pipe.scheduler.config)
16+
17+
coreml_pipe = pipeline.get_coreml_pipe(pytorch_pipe=pytorch_pipe,
18+
mlpackages_dir=args.i,
19+
model_version=args.model_version,
20+
compute_unit=args.compute_unit,
21+
scheduler_override=user_specified_scheduler)
22+
23+
24+
def infer(prompt, steps):
25+
pipeline.logger.info("Beginning image generation.")
26+
image = coreml_pipe(
27+
prompt=prompt,
28+
height=coreml_pipe.height,
29+
width=coreml_pipe.width,
30+
num_inference_steps=steps,
31+
)
32+
images = []
33+
images.append(image["images"][0])
34+
return images
35+
36+
37+
demo = gr.Blocks()
38+
39+
with demo:
40+
gr.Markdown(
41+
"<center><h1>Core ML Stable Diffusion</h1>Run Stable Diffusion on Apple Silicon with Core ML</center>")
42+
with gr.Group():
43+
with gr.Box():
44+
with gr.Row():
45+
with gr.Column():
46+
with gr.Row():
47+
text = gr.Textbox(
48+
label="Prompt",
49+
lines=11,
50+
placeholder="Enter your prompt",
51+
)
52+
with gr.Row():
53+
btn = gr.Button("Generate image")
54+
with gr.Row():
55+
steps = gr.Slider(label="Steps", minimum=1,
56+
maximum=50, value=10, step=1)
57+
with gr.Column():
58+
gallery = gr.Gallery(
59+
label="Generated image", elem_id="gallery"
60+
)
61+
62+
text.submit(infer, inputs=[text, steps], outputs=gallery)
63+
btn.click(infer, inputs=[text, steps], outputs=gallery)
64+
65+
demo.launch(debug=True, server_name="0.0.0.0")
66+
67+
68+
if __name__ == "__main__":
69+
parser = pipeline.argparse.ArgumentParser()
70+
71+
parser.add_argument(
72+
"-i",
73+
required=True,
74+
help=("Path to input directory with the .mlpackage files generated by "
75+
"python_coreml_stable_diffusion.torch2coreml"))
76+
parser.add_argument(
77+
"--model-version",
78+
default="CompVis/stable-diffusion-v1-4",
79+
help=
80+
("The pre-trained model checkpoint and configuration to restore. "
81+
"For available versions: https://huggingface.co/models?search=stable-diffusion"
82+
))
83+
parser.add_argument(
84+
"--compute-unit",
85+
choices=pipeline.get_available_compute_units(),
86+
default="ALL",
87+
help=("The compute units to be used when executing Core ML models. "
88+
f"Options: {pipeline.get_available_compute_units()}"))
89+
parser.add_argument(
90+
"--scheduler",
91+
choices=tuple(pipeline.SCHEDULER_MAP.keys()),
92+
default=None,
93+
help=("The scheduler to use for running the reverse diffusion process. "
94+
"If not specified, the default scheduler from the diffusers pipeline is utilized"))
95+
96+
args = parser.parse_args()
97+
init(args)

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@ coremltools
22
diffusers[torch]
33
torch
44
transformers
5-
scipy
5+
scipy
6+
gradio

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
"torch",
2020
"transformers",
2121
"scipy",
22+
"gradio",
2223
],
2324
packages=find_packages(),
2425
classifiers=[

0 commit comments

Comments
 (0)