
Commit fabd52c

SD PNDMScheduler + Unet example through Turbine (#403)
TODO: The rest of the schedulers need updates in diffusers upstream for the e2e test to work; it is xfailed for now.
1 parent f1c3d16 commit fabd52c

File tree

6 files changed: +438 -2 lines changed

core/shark_turbine/dynamo/passes.py

Lines changed: 1 addition & 0 deletions
@@ -48,6 +48,7 @@
     torch.ops.aten._log_softmax_backward_data,
     torch.ops.aten.lift_fresh_copy.default,
     torch.ops.aten._unsafe_index.Tensor,
+    torch.ops.aten.unbind.int,
     # decompositions added manually in this file
     torch.ops.aten._scaled_dot_product_flash_attention.default,
 ]
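
For context, a decomposition list like this is typically converted into a decomposition table that the tracer applies during export, so downstream backends only see simpler aten ops. A minimal sketch under that assumption, using the standard torch._decomp API (the surrounding passes.py code is not shown in this diff):

import torch
from torch._decomp import get_decompositions

# Hypothetical excerpt: ops named in the passes.py list are resolved to
# their registered decompositions before lowering.
decomp_ops = [
    torch.ops.aten.lift_fresh_copy.default,
    torch.ops.aten._unsafe_index.Tensor,
    torch.ops.aten.unbind.int,  # the op added by this commit
]
decomposition_table = get_decompositions(decomp_ops)

# A tracer such as make_fx can then apply the table during export, e.g.:
# from torch.fx.experimental.proxy_tensor import make_fx
# fx_graph = make_fx(model, decomposition_table=decomposition_table)(*inputs)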
Lines changed: 178 additions & 0 deletions
@@ -0,0 +1,178 @@
# Copyright 2024 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import os
import sys

import torch
from torch.fx.experimental.proxy_tensor import make_fx
from shark_turbine.aot import *
from iree import runtime as ireert
import iree.compiler as ireec
from iree.compiler.ir import Context
import numpy as np

from turbine_models.custom_models.sd_inference import utils
from diffusers import (
    UNet2DConditionModel,
)

import safetensors
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--hf_auth_token", type=str, help="The Hugging Face auth token, required"
)
parser.add_argument(
    "--hf_model_name",
    type=str,
    help="HF model name",
    default="CompVis/stable-diffusion-v1-4",
)
parser.add_argument(
    "--scheduler_id",
    type=str,
    help="Scheduler ID",
    default="PNDM",
)
parser.add_argument(
    "--num_inference_steps", type=int, default=50, help="Number of inference steps"
)
parser.add_argument(
    "--batch_size", type=int, default=1, help="Batch size for inference"
)
parser.add_argument(
    "--height", type=int, default=512, help="Height of Stable Diffusion"
)
parser.add_argument("--width", type=int, default=512, help="Width of Stable Diffusion")
parser.add_argument("--compile_to", type=str, help="torch, linalg, vmfb")
parser.add_argument("--external_weight_path", type=str, default="")
parser.add_argument(
    "--external_weights",
    type=str,
    default=None,
    help="saves ir/vmfb without global weights for size and readability, options [safetensors]",
)
parser.add_argument("--device", type=str, default="cpu", help="cpu, cuda, vulkan, rocm")
# TODO: Bring in detection for target triple
parser.add_argument(
    "--iree_target_triple",
    type=str,
    default="",
    help="Specify vulkan target triple or rocm/cuda target device.",
)
parser.add_argument("--vulkan_max_allocation", type=str, default="4294967296")


class Scheduler(torch.nn.Module):
    def __init__(self, hf_model_name, num_inference_steps, scheduler):
        super().__init__()
        self.scheduler = scheduler
        self.scheduler.set_timesteps(num_inference_steps)
        self.unet = UNet2DConditionModel.from_pretrained(
            hf_model_name,
            subfolder="unet",
        )
        self.guidance_scale = 7.5

    def forward(self, latents, encoder_hidden_states) -> torch.FloatTensor:
        latents = latents * self.scheduler.init_noise_sigma
        for t in self.scheduler.timesteps:
            latent_model_input = torch.cat([latents] * 2)
            t = t.unsqueeze(0)
            latent_model_input = self.scheduler.scale_model_input(
                latent_model_input, timestep=t
            )
            unet_out = self.unet.forward(
                latent_model_input, t, encoder_hidden_states, return_dict=False
            )[0]
            noise_pred_uncond, noise_pred_text = unet_out.chunk(2)
            noise_pred = noise_pred_uncond + self.guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )
            latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
        return latents


def export_scheduler(
    scheduler,
    hf_model_name,
    batch_size,
    height,
    width,
    hf_auth_token=None,
    compile_to="torch",
    external_weights=None,
    external_weight_path=None,
    device=None,
    target_triple=None,
    max_alloc=None,
):
    mapper = {}
    utils.save_external_weights(
        mapper, scheduler, external_weights, external_weight_path
    )

    encoder_hidden_states_sizes = (2, 77, 768)
    if hf_model_name == "stabilityai/stable-diffusion-2-1-base":
        encoder_hidden_states_sizes = (2, 77, 1024)

    sample = (batch_size, 4, height // 8, width // 8)

    class CompiledScheduler(CompiledModule):
        if external_weights:
            params = export_parameters(
                scheduler, external=True, external_scope="", name_mapper=mapper.get
            )
        else:
            params = export_parameters(scheduler)

        def main(
            self,
            sample=AbstractTensor(*sample, dtype=torch.float32),
            encoder_hidden_states=AbstractTensor(
                *encoder_hidden_states_sizes, dtype=torch.float32
            ),
        ):
            return jittable(scheduler.forward)(sample, encoder_hidden_states)

    import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
    inst = CompiledScheduler(context=Context(), import_to=import_to)

    module_str = str(CompiledModule.get_mlir_module(inst))
    safe_name = utils.create_safe_name(hf_model_name, "-scheduler")
    if compile_to != "vmfb":
        return module_str
    else:
        utils.compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name)


if __name__ == "__main__":
    args = parser.parse_args()
    schedulers = utils.get_schedulers(args.hf_model_name)
    scheduler = schedulers[args.scheduler_id]
    scheduler_module = Scheduler(
        args.hf_model_name, args.num_inference_steps, scheduler
    )
    mod_str = export_scheduler(
        scheduler_module,
        args.hf_model_name,
        args.batch_size,
        args.height,
        args.width,
        args.hf_auth_token,
        args.compile_to,
        args.external_weights,
        args.external_weight_path,
        args.device,
        args.iree_target_triple,
        args.vulkan_max_allocation,
    )
    safe_name = utils.create_safe_name(args.hf_model_name, "-scheduler")
    with open(f"{safe_name}.mlir", "w+") as f:
        f.write(mod_str)
    print("Saved to", safe_name + ".mlir")
Lines changed: 172 additions & 0 deletions
@@ -0,0 +1,172 @@
# Copyright 2024 Advanced Micro Devices, Inc
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import argparse
from turbine_models.model_runner import vmfbRunner
from iree import runtime as ireert
import torch
from diffusers import (
    PNDMScheduler,
    UNet2DConditionModel,
)

parser = argparse.ArgumentParser()

# TODO move common runner flags to generic flag file
parser.add_argument(
    "--scheduler_id",
    type=str,
    help="Scheduler ID",
    default="PNDM",
)
parser.add_argument(
    "--num_inference_steps", type=int, default=50, help="Number of inference steps"
)
parser.add_argument(
    "--vmfb_path", type=str, default="", help="path to vmfb containing compiled module"
)
parser.add_argument(
    "--external_weight_path",
    type=str,
    default="",
    help="path to external weight parameters if model compiled without them",
)
parser.add_argument(
    "--compare_vs_torch",
    action="store_true",
    help="Runs both turbine vmfb and a torch model to compare results",
)
parser.add_argument(
    "--hf_model_name",
    type=str,
    help="HF model name",
    default="CompVis/stable-diffusion-v1-4",
)
parser.add_argument(
    "--hf_auth_token",
    type=str,
    help="The Hugging Face auth token, required for some models",
)
parser.add_argument(
    "--device",
    type=str,
    default="local-task",
    help="local-sync, local-task, cuda, vulkan, rocm",
)
parser.add_argument(
    "--batch_size", type=int, default=1, help="Batch size for inference"
)
parser.add_argument(
    "--height", type=int, default=512, help="Height of Stable Diffusion"
)
parser.add_argument("--width", type=int, default=512, help="Width of Stable Diffusion")


def run_scheduler(
    device,
    sample,
    encoder_hidden_states,
    vmfb_path,
    hf_model_name,
    hf_auth_token,
    external_weight_path,
):
    runner = vmfbRunner(device, vmfb_path, external_weight_path)

    inputs = [
        ireert.asdevicearray(runner.config.device, sample),
        ireert.asdevicearray(runner.config.device, encoder_hidden_states),
    ]
    results = runner.ctx.modules.compiled_scheduler["main"](*inputs)
    return results


def run_torch_scheduler(
    hf_model_name, scheduler, num_inference_steps, sample, encoder_hidden_states
):
    class Scheduler(torch.nn.Module):
        def __init__(self, hf_model_name, num_inference_steps, scheduler):
            super().__init__()
            self.scheduler = scheduler
            self.scheduler.set_timesteps(num_inference_steps)
            self.unet = UNet2DConditionModel.from_pretrained(
                hf_model_name,
                subfolder="unet",
            )
            self.guidance_scale = 7.5

        def forward(self, latents, encoder_hidden_states) -> torch.FloatTensor:
            latents = latents * self.scheduler.init_noise_sigma
            for t in self.scheduler.timesteps:
                latent_model_input = torch.cat([latents] * 2)
                t = t.unsqueeze(0)
                latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, timestep=t
                )
                unet_out = self.unet.forward(
                    latent_model_input, t, encoder_hidden_states, return_dict=False
                )[0]
                noise_pred_uncond, noise_pred_text = unet_out.chunk(2)
                noise_pred = noise_pred_uncond + self.guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )
                latents = self.scheduler.step(
                    noise_pred, t, latents, return_dict=False
                )[0]
            return latents

    scheduler_module = Scheduler(hf_model_name, num_inference_steps, scheduler)
    results = scheduler_module.forward(sample, encoder_hidden_states)
    np_torch_output = results.detach().cpu().numpy()
    return np_torch_output


if __name__ == "__main__":
    args = parser.parse_args()
    sample = torch.rand(
        args.batch_size, 4, args.height // 8, args.width // 8, dtype=torch.float32
    )
    if args.hf_model_name == "CompVis/stable-diffusion-v1-4":
        encoder_hidden_states = torch.rand(2, 77, 768, dtype=torch.float32)
    elif args.hf_model_name == "stabilityai/stable-diffusion-2-1-base":
        encoder_hidden_states = torch.rand(2, 77, 1024, dtype=torch.float32)

    turbine_output = run_scheduler(
        args.device,
        sample,
        encoder_hidden_states,
        args.vmfb_path,
        args.hf_model_name,
        args.hf_auth_token,
        args.external_weight_path,
    )
    print(
        "TURBINE OUTPUT:",
        turbine_output.to_host(),
        turbine_output.to_host().shape,
        turbine_output.to_host().dtype,
    )

    if args.compare_vs_torch:
        print("generating torch output: ")
        from turbine_models.custom_models.sd_inference import utils

        schedulers = utils.get_schedulers(args.hf_model_name)
        scheduler = schedulers[args.scheduler_id]
        torch_output = run_torch_scheduler(
            args.hf_model_name,
            scheduler,
            args.num_inference_steps,
            sample,
            encoder_hidden_states,
        )
        print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype)
        err = utils.largest_error(torch_output, turbine_output)
        print("Largest Error: ", err)
        assert err < 9e-3

    # TODO: Figure out why we occasionally segfault without unlinking output variables
    turbine_output = None
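
The comparison relies on utils.largest_error, whose body is not part of this diff. A plausible minimal equivalent, assuming it reduces to the maximum absolute elementwise difference (the name largest_error_sketch and the implementation below are assumptions, not the committed code):

import numpy as np

def largest_error_sketch(torch_output, turbine_output):
    # Hypothetical stand-in for utils.largest_error: compare the torch
    # reference against the IREE device array element by element.
    expected = np.asarray(torch_output, dtype=np.float32)
    actual = np.asarray(turbine_output.to_host(), dtype=np.float32)
    return float(np.max(np.abs(expected - actual)))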

models/turbine_models/custom_models/sd_inference/utils.py

Lines changed: 22 additions & 0 deletions
@@ -2,6 +2,9 @@
 import numpy as np
 import safetensors
 import re
+from diffusers import (
+    PNDMScheduler,
+)


 def save_external_weights(
@@ -35,6 +38,7 @@ def compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name):
         "--iree-llvmcpu-target-triple=x86_64-linux-gnu",
         "--iree-stream-resource-index-bits=64",
         "--iree-vm-target-index-bits=64",
+        "--iree-flow-inline-constants-max-byte-length=1",
     ]
     if device == "cpu":
         flags.append("--iree-llvmcpu-enable-ukernels=all")
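
These flags are ultimately forwarded to the IREE compiler. A minimal sketch of that step, assuming compile_to_vmfb wraps iree.compiler.compile_str roughly as follows (the rest of compile_to_vmfb's body is not shown in this diff; the function name below is hypothetical):

import iree.compiler as ireec

def compile_flags_sketch(module_str, flags, safe_name):
    # Hypothetical reconstruction: compile MLIR text to a VM flatbuffer
    # for the CPU backend, passing through the extra flags shown above.
    flatbuffer = ireec.compile_str(
        module_str,
        target_backends=["llvm-cpu"],
        extra_args=flags,
    )
    with open(f"{safe_name}.vmfb", "wb") as f:
        f.write(flatbuffer)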
@@ -86,3 +90,21 @@ def create_safe_name(hf_model_name, model_name_str):
     safe_name = hf_model_name.split("/")[-1].strip() + model_name_str
     safe_name = re.sub("-", "_", safe_name)
     return safe_name
+
+
+def get_schedulers(model_id):
+    # TODO: Robust scheduler setup on pipeline creation -- if we don't
+    # set batch_size here, the SHARK schedulers will compile with
+    # batch size = 1 regardless of whether the model outputs latents
+    # of a larger batch size, e.g. SDXL. However, searching for whether
+    # the base model ID contains "xl" is obviously not very robust.
+
+    batch_size = 2 if "xl" in model_id.lower() else 1
+
+    schedulers = dict()
+    schedulers["PNDM"] = PNDMScheduler.from_pretrained(
+        model_id,
+        subfolder="scheduler",
+    )
+    return schedulers
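
Per the commit message, the remaining schedulers still need updates in diffusers upstream; once they land, get_schedulers would presumably grow more entries. A hedged sketch of that extension (EulerDiscreteScheduler is an example choice and the function name is hypothetical; neither is part of this commit):

from diffusers import EulerDiscreteScheduler, PNDMScheduler

def get_schedulers_extended(model_id):
    # Hypothetical extension: register additional diffusers schedulers
    # alongside PNDM once they are export-friendly upstream.
    schedulers = dict()
    schedulers["PNDM"] = PNDMScheduler.from_pretrained(
        model_id, subfolder="scheduler"
    )
    schedulers["EulerDiscrete"] = EulerDiscreteScheduler.from_pretrained(
        model_id, subfolder="scheduler"
    )
    return schedulers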
