Skip to content

Commit 3495f63

Browse files
committed
Achieve basic functionality for sd3 txt2img
1 parent 585f680 commit 3495f63

File tree

13 files changed

+1659
-146
lines changed

13 files changed

+1659
-146
lines changed

models/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ shark-turbine @ git+https://github.com/iree-org/iree-turbine.git@main
44
transformers==4.37.1
55
torchsde
66
accelerate
7-
diffusers @ git+https://github.com/nod-ai/diffusers@v0.28.2-shark
7+
diffusers @ git+https://github.com/nod-ai/diffusers@0.29.0.dev0-shark
88
brevitas @ git+https://github.com/Xilinx/brevitas.git@6695e8df7f6a2c7715b9ed69c4b78157376bb60b
99
# turbine tank downloading/uploading
1010
azure-storage-blob

models/setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def load_version_info():
5959
"sentencepiece",
6060
"transformers==4.37.1",
6161
"accelerate",
62-
"diffusers==0.24.0",
62+
"diffusers==0.29.0.dev0",
6363
"azure-storage-blob",
6464
"einops",
6565
],

models/turbine_models/custom_models/sd3_inference/sd3_cmd_opts.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,7 @@ def is_valid_file(arg):
9595
p.add_argument(
9696
"--guidance_scale",
9797
type=float,
98-
default=7.5,
98+
default=4,
9999
help="Scale by which to adjust prompt guidance to the unconditional noise prediction output of UNet after each iteration.",
100100
)
101101

@@ -207,9 +207,15 @@ def is_valid_file(arg):
207207
p.add_argument(
208208
"--vae_decomp_attn",
209209
type=bool,
210-
default=False,
210+
default=True,
211211
help="Decompose attention for VAE decode only at fx graph level",
212212
)
213+
p.add_argument(
214+
"--vae_dtype",
215+
type=str,
216+
default="fp32",
217+
help="Precision of VAE graph.",
218+
)
213219

214220
##############################################################################
215221
# SD3 script general options.
@@ -271,11 +277,7 @@ def is_valid_file(arg):
271277
default=None,
272278
help="Azure storage container name to download mlir files from.",
273279
)
274-
p.add_argument(
275-
"--export",
276-
type=str,
277-
default="all",
278-
help="clip, mmdit, vae, all")
280+
p.add_argument("--export", type=str, default="all", help="clip, mmdit, vae, all")
279281
p.add_argument(
280282
"--output",
281283
type=str,
Lines changed: 277 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,277 @@
1+
# Copyrigh 2023 Nod Labs, Inc
2+
#
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
import os
8+
import sys
9+
10+
from iree import runtime as ireert
11+
from iree.compiler.ir import Context
12+
import numpy as np
13+
from shark_turbine.aot import *
14+
from turbine_models.custom_models.sd_inference import utils
15+
import torch
16+
import torch._dynamo as dynamo
17+
18+
import safetensors
19+
import argparse
20+
from turbine_models.turbine_tank import turbine_tank
21+
22+
SEED = 1
23+
24+
25+
def export_vae(
26+
model,
27+
height,
28+
width,
29+
compile_to="torch",
30+
external_weight_prefix=None,
31+
device=None,
32+
target_triple=None,
33+
max_alloc="",
34+
upload_ir=False,
35+
dtype=torch.float32,
36+
):
37+
mapper = {}
38+
utils.save_external_weights(mapper, model, "safetensors", external_weight_prefix)
39+
latent_shape = [1, 16, height // 8, width // 8]
40+
input_arg = torch.empty(latent_shape)
41+
input_arg = (input_arg.to(dtype),)
42+
if external_weight_prefix != None and len(external_weight_prefix) > 1:
43+
externalize_module_parameters(model)
44+
45+
exported = export(model, args=input_arg)
46+
47+
module_str = str(exported.mlir_module)
48+
safe_name = utils.create_safe_name(str(dtype).lstrip("torch."), "_mmdit")
49+
if compile_to != "vmfb":
50+
return module_str
51+
else:
52+
print("compiling to vmfb")
53+
utils.compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name)
54+
return module_str
55+
56+
57+
def export_unet_dynamic(
58+
unet_model,
59+
height,
60+
width,
61+
compile_to="torch",
62+
external_weight_path=None,
63+
device=None,
64+
target_triple=None,
65+
max_alloc="",
66+
upload_ir=False,
67+
dtype=torch.float32,
68+
):
69+
cond_shape = [1, 154, 4096] # 77, 4096]
70+
pool_shape = [1, 2048]
71+
latent_shape = [1, 16, height // 8, width // 8]
72+
if dtype == torch.float16:
73+
unet_model = unet_model.half()
74+
mapper = {}
75+
utils.save_external_weights(mapper, unet_model, "safetensors", external_weight_path)
76+
77+
if weights_only:
78+
return external_weight_path
79+
80+
fxb = FxProgramsBuilder(unet_model)
81+
82+
sigmas = torch.export.Dim("sigmas")
83+
dynamic_shapes = {"sigmas": {0: sigmas}, "latent": {}, "noise": {}}
84+
example_init_args = [
85+
torch.empty([19], dtype=dtype),
86+
torch.empty(latent_shape, dtype=dtype),
87+
torch.empty(latent_shape, dtype=dtype),
88+
]
89+
example_sampling_args = [
90+
torch.empty(latent_shape, dtype=dtype),
91+
torch.empty(1, dtype=dtype),
92+
torch.empty(1, dtype=dtype),
93+
torch.empty(cond_shape, dtype=dtype),
94+
torch.empty(pool_shape, dtype=dtype),
95+
torch.empty(cond_shape, dtype=dtype),
96+
torch.empty(pool_shape, dtype=dtype),
97+
torch.empty(1, dtype=dtype),
98+
]
99+
100+
@fxb.export_program(args=(example_init_args,), dynamic_shapes=dynamic_shapes)
101+
def _initialize(module, inputs):
102+
# 1.0 is denoise currently symfloat not supported in fx_importer
103+
return module.init_dynamic(*inputs)
104+
105+
@fxb.export_program(args=(example_sampling_args,))
106+
def _do_sampling(module, inputs):
107+
return module.do_sampling(*inputs)
108+
109+
class CompiledTresleches(CompiledModule):
110+
initialize = _initialize
111+
do_sampling = _do_sampling
112+
113+
# _vae_decode = vae_decode
114+
115+
if external_weights:
116+
externalize_module_parameters(unet_model)
117+
save_module_parameters(external_weight_path, unet_model)
118+
119+
inst = CompiledTresleches(context=Context(), import_to="IMPORT")
120+
module_str = str(CompiledModule.get_mlir_module(inst))
121+
print("exported model")
122+
123+
safe_name = utils.create_safe_name(str(dtype).lstrip("torch."), "_mmdit")
124+
if compile_to != "vmfb":
125+
return module_str
126+
else:
127+
print("compiling to vmfb")
128+
utils.compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name)
129+
return module_str
130+
131+
132+
def export_preprocessor(
133+
model,
134+
compile_to="torch",
135+
external_weight_path=None,
136+
device=None,
137+
target_triple=None,
138+
max_alloc="",
139+
dtype=torch.float32,
140+
height=512,
141+
width=512,
142+
):
143+
external_weights = "safetensors"
144+
145+
def get_noise():
146+
latent = torch.ones(1, 16, height // 8, width // 8, device="cpu") * 0.0609
147+
generator = torch.manual_seed(SEED)
148+
return torch.randn(
149+
latent.size(),
150+
dtype=latent.dtype,
151+
layout=latent.layout,
152+
generator=generator,
153+
device="cpu",
154+
)
155+
156+
input_args = [torch.empty([1, 77, 2], dtype=torch.int64) for x in range(6)]
157+
input_args += get_noise()
158+
if dtype == torch.float16:
159+
model = model.half()
160+
161+
mapper = {}
162+
163+
utils.save_external_weights(mapper, model, external_weights, external_weight_path)
164+
165+
if external_weight_path != None and len(external_weight_path) > 1:
166+
print("externalizing weights")
167+
externalize_module_parameters(model)
168+
169+
exported = export(model, args=tuple(input_args))
170+
print("exported model")
171+
172+
# import_to = "INPUT" if compile_to == "linalg" else "IMPORT"
173+
# inst = CompiledTresleches(context=Context(), import_to=import_to)
174+
175+
# module_str = str(CompiledModule.get_mlir_module(inst))
176+
module_str = str(exported.mlir_module)
177+
safe_name = utils.create_safe_name("sd3", "clips")
178+
if compile_to != "vmfb":
179+
return module_str
180+
else:
181+
print("compiling to vmfb")
182+
utils.compile_to_vmfb(module_str, device, target_triple, max_alloc, safe_name)
183+
return module_str
184+
185+
186+
@torch.no_grad()
187+
def main(args):
188+
import turbine_sd3
189+
from safetensors import safe_open
190+
191+
vulkan_max_allocation = "4294967296" if args.device == "vulkan" else ""
192+
# st_file = "/mnt2/tresleches/models/sd3_8b_beta.safetensors"
193+
st_file = "/mnt2/tresleches/models/sd3_2b_512_alpha.safetensors"
194+
dtype = torch.float32
195+
if args.precision == "f16":
196+
dtype = torch.float16
197+
elif args.precision == "bf16":
198+
dtype = torch.bfloat16
199+
print(args.export)
200+
201+
if args.export in ["dynamic"]:
202+
print("exporting dynamic")
203+
unet_model = turbine_sd3.SD3Inferencer(
204+
model=st_file, vae=turbine_sd3.VAEFile, shift=1.0, dtype=dtype
205+
).eval()
206+
mod_str = export_unet_dynamic(
207+
unet_model=unet_model,
208+
height=args.height,
209+
width=args.width,
210+
compile_to=args.compile_to,
211+
external_weight_path=args.external_weight_path,
212+
device=args.device,
213+
target_triple=args.iree_target_triple,
214+
max_alloc=vulkan_max_allocation,
215+
upload_ir=False,
216+
dtype=dtype,
217+
)
218+
safe_name = utils.create_safe_name("hc_sd3", "-unet")
219+
with open(f"{safe_name}.mlir", "w+") as f:
220+
f.write(mod_str)
221+
print("Saved to", safe_name + ".mlir")
222+
export_pre = args.export in ["all", "clip"]
223+
print(export_pre)
224+
if export_pre:
225+
print("exporting preprocessor")
226+
pre = turbine_sd3.Preprocess()
227+
mod_str = export_preprocessor(
228+
model=pre,
229+
compile_to=args.compile_to,
230+
external_weight_path=args.external_weight_path,
231+
device=args.device,
232+
target_triple=args.iree_target_triple,
233+
max_alloc=vulkan_max_allocation,
234+
dtype=dtype,
235+
height=args.height,
236+
width=args.width,
237+
)
238+
safe_name = utils.create_safe_name("hc_sd3", "_preprocess")
239+
with open(f"{safe_name}.mlir", "w+") as f:
240+
f.write(mod_str)
241+
print("Saved to", safe_name + ".mlir")
242+
should_export_vae = args.export in ["all", "vae"]
243+
if should_export_vae:
244+
print("exporting vae")
245+
from turbine_impls import SDVAE
246+
247+
with turbine_sd3.safe_open(
248+
turbine_sd3.VAEFile, framework="pt", device="cpu"
249+
) as f:
250+
vae = SDVAE(device="cpu", dtype=dtype).eval().cpu()
251+
prefix = ""
252+
if any(k.startswith("first_stage_model.") for k in f.keys()):
253+
prefix = "first_stage_model."
254+
turbine_sd3.load_into(f, vae, prefix, "cpu", dtype)
255+
print("Something")
256+
mod_str = export_vae(
257+
model=vae,
258+
height=args.height,
259+
width=args.width,
260+
compile_to=args.compile_to,
261+
external_weight_prefix=args.external_weight_path,
262+
device=args.device,
263+
target_triple=args.iree_target_triple,
264+
max_alloc=vulkan_max_allocation,
265+
dtype=dtype,
266+
)
267+
safe_name = utils.create_safe_name("hc_sd3", "_vae")
268+
with open(f"{safe_name}.mlir", "w+") as f:
269+
f.write(mod_str)
270+
print("Saved to", safe_name + ".mlir")
271+
272+
273+
if __name__ == "__main__":
274+
from turbine_models.custom_models.sd3_inference.sd3_cmd_opts import args
275+
276+
torch._dynamo.config.capture_scalar_outputs = True
277+
main(args)

0 commit comments

Comments
 (0)