Skip to content

Commit fd2a2ba

Browse files
committed
Fixes for vae precision/attn decomposition, numerics validation
1 parent 8b775aa commit fd2a2ba

File tree

6 files changed

+54
-33
lines changed

6 files changed

+54
-33
lines changed

models/turbine_models/custom_models/sd3_inference/sd3_cmd_opts.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,12 @@ def is_valid_file(arg):
247247
default="fp16",
248248
help="Precision of Stable Diffusion weights and graph.",
249249
)
250+
p.add_argument(
251+
"--vae_precision",
252+
type=str,
253+
default=None,
254+
help="Precision of Stable Diffusion VAE weights and graph.",
255+
)
250256
p.add_argument(
251257
"--max_length", type=int, default=77, help="Sequence Length of Stable Diffusion"
252258
)
@@ -257,7 +263,7 @@ def is_valid_file(arg):
257263
p.add_argument(
258264
"--vae_decomp_attn",
259265
type=bool,
260-
default=True,
266+
default=False,
261267
help="Decompose attention for VAE decode only at fx graph level",
262268
)
263269
p.add_argument(

models/turbine_models/custom_models/sd3_inference/sd3_pipeline.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,6 @@ def __init__(
4646
hf_model_name: str,
4747
height: int,
4848
width: int,
49-
shift: float,
5049
precision: str,
5150
max_length: int,
5251
batch_size: int,
@@ -59,10 +58,12 @@ def __init__(
5958
pipeline_dir: str = "./shark_vmfbs",
6059
external_weights_dir: str = "./shark_weights",
6160
external_weights: str = "safetensors",
62-
vae_decomp_attn: bool = True,
63-
custom_vae: str = "",
61+
vae_decomp_attn: bool = False,
6462
cpu_scheduling: bool = False,
63+
vae_precision: str = "fp32",
6564
scheduler_id: str = None, #compatibility only, always uses EulerFlowScheduler
65+
shift: float = 1.0,
66+
6667
):
6768
self.hf_model_name = hf_model_name
6869
# self.scheduler_id = scheduler_id
@@ -120,10 +121,11 @@ def __init__(
120121
self.external_weights_dir = external_weights_dir
121122
self.external_weights = external_weights
122123
self.vae_decomp_attn = vae_decomp_attn
123-
self.custom_vae = custom_vae
124+
self.custom_vae = None
124125
self.cpu_scheduling = cpu_scheduling
125126
self.torch_dtype = torch.float32 if self.precision == "fp32" else torch.float16
126-
self.vae_dtype = torch.float32
127+
self.vae_precision = vae_precision if vae_precision else self.precision
128+
self.vae_dtype = torch.float32 if vae_precision == "fp32" else torch.float16
127129
# TODO: set this based on user-inputted guidance scale and negative prompt.
128130
self.do_classifier_free_guidance = True # False if any(x in hf_model_name for x in ["turbo", "lightning"]) else True
129131

@@ -206,7 +208,12 @@ def is_prepared(self, vmfbs, weights):
206208
)
207209
if w_key == "clip":
208210
default_name = os.path.join(
209-
self.external_weights_dir, f"sd3_clip_fp16.irpa"
211+
self.external_weights_dir, f"sd3_text_encoders_{self.precision}.irpa"
212+
)
213+
if w_key == "mmdit":
214+
default_name = os.path.join(
215+
self.external_weights_dir,
216+
f"sd3_mmdit_{self.precision}." + self.external_weights,
210217
)
211218
if weights[w_key] is None and os.path.exists(default_name):
212219
weights[w_key] = os.path.join(default_name)
@@ -357,7 +364,7 @@ def export_submodel(
357364
self.batch_size,
358365
self.height,
359366
self.width,
360-
"fp32",
367+
self.vae_precision,
361368
"vmfb",
362369
self.external_weights,
363370
vae_external_weight_path,
@@ -586,7 +593,8 @@ def generate_images(
586593
dtype=self.vae_dtype,
587594
)
588595
else:
589-
latents = sample.astype("float32")
596+
vae_numpy_dtype = np.float32 if self.vae_precision == "fp32" else np.float16
597+
latents = sample.astype(vae_numpy_dtype)
590598

591599
vae_start = time.time()
592600
vae_out = self.runners["vae"].ctx.modules.compiled_vae["decode"](latents)
@@ -634,7 +642,7 @@ def generate_images(
634642
out_image = Image.fromarray(image)
635643
images.extend([[out_image]])
636644
if return_imgs:
637-
return images
645+
return images[0]
638646
for idx_batch, image_batch in enumerate(images):
639647
for idx, image in enumerate(image_batch):
640648
img_path = (
@@ -767,7 +775,6 @@ def run_diffusers_cpu(
767775
args.hf_model_name,
768776
args.height,
769777
args.width,
770-
args.shift,
771778
args.precision,
772779
args.max_length,
773780
args.batch_size,
@@ -779,9 +786,8 @@ def run_diffusers_cpu(
779786
args.decomp_attn,
780787
args.pipeline_dir,
781788
args.external_weights_dir,
782-
args.external_weights,
783-
args.vae_decomp_attn,
784-
custom_vae=None,
789+
external_weights=args.external_weights,
790+
vae_decomp_attn=args.vae_decomp_attn,
785791
cpu_scheduling=args.cpu_scheduling,
786792
vae_precision=args.vae_precision,
787793
)

models/turbine_models/custom_models/sd3_inference/sd3_vae_runner.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ def run_vae(
1515
):
1616
runner = vmfbRunner(device, vmfb_path, external_weight_path)
1717
inputs = [ireert.asdevicearray(runner.config.device, example_input)]
18-
results = runner.ctx.modules.compiled_vae["decode"](*inputs)
19-
18+
results = runner.ctx.modules.compiled_vae["decode"](*inputs).to_host()
19+
results = imagearray_from_vae_out(results)
2020
return results
2121

2222

@@ -32,11 +32,19 @@ def run_torch_vae(hf_model_name, variant, example_input):
3232
elif variant == "encode":
3333
results = vae_model.encode(example_input)
3434
np_torch_output = results.detach().cpu().numpy()
35+
np_torch_output = imagearray_from_vae_out(np_torch_output)
3536
return np_torch_output
3637

38+
def imagearray_from_vae_out(image):
39+
if image.ndim == 4:
40+
image = image[0]
41+
image = torch.from_numpy(image).cpu().permute(1, 2, 0).float().numpy()
42+
image = (image * 255).round().astype("uint8")
43+
return image
3744

3845
if __name__ == "__main__":
3946
from turbine_models.custom_models.sd3_inference.sd3_cmd_opts import args
47+
import numpy as np
4048

4149
dtype = torch.float16 if args.precision == "fp16" else torch.float32
4250
if args.vae_variant == "decode":
@@ -57,9 +65,9 @@ def run_torch_vae(hf_model_name, variant, example_input):
5765
)
5866
print(
5967
"TURBINE OUTPUT:",
60-
turbine_results.to_host(),
61-
turbine_results.to_host().shape,
62-
turbine_results.to_host().dtype,
68+
turbine_results,
69+
turbine_results.shape,
70+
turbine_results.dtype,
6371
)
6472
if args.compare_vs_torch:
6573
print("generating torch output: ")
@@ -69,9 +77,10 @@ def run_torch_vae(hf_model_name, variant, example_input):
6977
args.hf_model_name, args.vae_variant, example_input.float()
7078
)
7179
print("TORCH OUTPUT:", torch_output, torch_output.shape, torch_output.dtype)
72-
err = utils.largest_error(torch_output, turbine_results)
73-
print("Largest Error: ", err)
74-
assert err < 2e-3
80+
# Allow a small amount of wiggle room for rounding errors (1)
81+
np.testing.assert_allclose(
82+
turbine_results, torch_output, rtol=1, atol=1
83+
)
7584

7685
# TODO: Figure out why we occasionally segfault without unlinking output variables
7786
turbine_results = None

models/turbine_models/custom_models/sd3_inference/text_encoder_impls.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -341,8 +341,10 @@ def __init__(self):
341341
self.clip_g = SDXLClipGTokenizer(clip_tokenizer)
342342
self.t5xxl = T5XXLTokenizer()
343343

344-
def tokenize_with_weights(self, text: str):
344+
def tokenize_with_weights(self, text: str | list[str]):
345345
out = {}
346+
if isinstance(text, list):
347+
text = text[0]
346348
out["g"] = self.clip_g.tokenize_with_weights(text)
347349
out["l"] = self.clip_l.tokenize_with_weights(text)
348350
out["t5xxl"] = self.t5xxl.tokenize_with_weights(text)

models/turbine_models/custom_models/sdxl_inference/unet_runner.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,8 @@ def run_unet(
3131
ireert.asdevicearray(runner.config.device, prompt_embeds),
3232
ireert.asdevicearray(runner.config.device, text_embeds),
3333
ireert.asdevicearray(runner.config.device, time_ids),
34-
ireert.asdevicearray(runner.config.device, guidance_scale),
3534
]
36-
results = runner.ctx.modules.compiled_unet["main"](*inputs)
35+
results = runner.ctx.modules.compiled_unet["run_forward"](*inputs)
3736

3837
return results
3938

@@ -57,7 +56,6 @@ def run_unet_steps(
5756
ireert.asdevicearray(runner.config.device, prompt_embeds),
5857
ireert.asdevicearray(runner.config.device, text_embeds),
5958
ireert.asdevicearray(runner.config.device, time_ids),
60-
ireert.asdevicearray(runner.config.device, (guidance_scale,)),
6159
]
6260
for i, t in tqdm(enumerate(scheduler.timesteps)):
6361
timestep = t
@@ -69,7 +67,7 @@ def run_unet_steps(
6967
inputs[1] = timestep = ireert.asdevicearray(
7068
runner.config.device, (timestep,), dtype="int64"
7169
)
72-
noise_pred = runner.ctx.modules.compiled_unet["main"](*inputs).to_host()
70+
noise_pred = runner.ctx.modules.compiled_unet["run_forward"](*inputs).to_host()
7371
sample = scheduler.step(
7472
torch.from_numpy(noise_pred).cpu(),
7573
timestep,

models/turbine_models/model_runner.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import argparse
22
import sys
33
from iree import runtime as ireert
4-
#from iree.runtime._binding import create_hal_driver
4+
from iree.runtime._binding import create_hal_driver
55

66

77
class vmfbRunner:
@@ -11,14 +11,14 @@ def __init__(self, device, vmfb_path, external_weight_path=None, extra_plugin=No
1111
# If an extra plugin is requested, add a global flag to load the plugin
1212
# and create the driver using the non-caching creation function, as
1313
# the caching creation function may ignore the flag.
14-
# if extra_plugin:
15-
# ireert.flags.parse_flags(f"--executable_plugin={extra_plugin}")
16-
# haldriver = create_hal_driver(device)
14+
if extra_plugin:
15+
ireert.flags.parse_flags(f"--executable_plugin={extra_plugin}")
16+
haldriver = create_hal_driver(device)
1717

1818
# No plugin requested: create the driver with the caching create
1919
# function.
20-
#else:
21-
haldriver = ireert.get_driver(device)
20+
else:
21+
haldriver = ireert.get_driver(device)
2222
if "://" in device:
2323
try:
2424
device_idx = int(device.split("://")[-1])

0 commit comments

Comments
 (0)