Commit 7b861a7

Fix formatting

1 parent: f6ab086

7 files changed: +83 -36 lines changed

models/turbine_models/custom_models/sd3_inference/sd3_mmdit.py

Lines changed: 5 additions & 3 deletions
@@ -52,10 +52,11 @@ def forward(
             return_dict=False,
         )[0]
         return noise_pred
-
+
+
 class MMDiTAttention(torch.nn.Module):
     def __init__(
-        self,
+        self,
     ):
         super().__init__()
@@ -84,7 +85,7 @@ def export_attn(
 
     if dtype == torch.float16:
         attn_module = attn_module.half()
-
+
     example_qkv = [
         torch.empty(qkv_shape, dtype=dtype),
         torch.empty(qkv_shape, dtype=dtype),
@@ -134,6 +135,7 @@ class CompiledAttn(CompiledModule):
     )
     return vmfb_path
 
+
 @torch.no_grad()
 def export_mmdit_model(
     mmdit_model,
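For context, export_attn above feeds empty q/k/v tensors through a bare MMDiTAttention module to trace and compile it. The module's forward body is not part of this diff; below is a minimal sketch of such a wrapper, assuming it simply calls PyTorch's fused attention op.

import torch


class AttentionSketch(torch.nn.Module):
    # Hypothetical stand-in for MMDiTAttention; only the calling convention
    # (three q/k/v tensors in, one tensor out) is taken from the diff.
    def __init__(self):
        super().__init__()

    def forward(self, q, k, v):
        # Expects (batch, heads, seq_len, head_dim) tensors.
        return torch.nn.functional.scaled_dot_product_attention(
            q, k, v, dropout_p=0.0, is_causal=False
        )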

models/turbine_models/custom_models/sd3_inference/sd3_mmdit_runner.py

Lines changed: 12 additions & 4 deletions
@@ -57,6 +57,7 @@ def run_diffusers_mmdit(
 
     return noise_pred.numpy()
 
+
 def run_attn_turbine(q, k, v, args):
     attn_runner = vmfbRunner(
         args.device,
@@ -73,6 +74,7 @@ def run_attn_turbine(q, k, v, args):
     ).to_host()
     return attn_output
 
+
 @torch.no_grad()
 def run_attn_torch(q, k, v, args):
     from turbine_models.custom_models.sd3_inference.sd3_mmdit import MMDiTAttention
@@ -86,21 +88,27 @@ def run_attn_torch(q, k, v, args):
 
     return attn_output.numpy()
 
+
 def find_errs(turbine_output, torch_output, dim=[], failed_dims=[], errs=[]):
     if not np.allclose(turbine_output, torch_output, rtol=4e-2, atol=4e-2):
         if turbine_output.ndim > 0:
             orig_dim = dim
             for idx, i in enumerate(torch_output):
                 dim = [*orig_dim, idx]
                 try:
-                    np.testing.assert_allclose(turbine_output[idx], torch_output[idx], rtol=4e-2, atol=4e-2)
+                    np.testing.assert_allclose(
+                        turbine_output[idx], torch_output[idx], rtol=4e-2, atol=4e-2
+                    )
                 except Exception as e:
                     err = np.abs(turbine_output[idx] - torch_output[idx])
                     failed_dims.append(dim)
                     errs.append([err, turbine_output[idx], torch_output[idx]])
-                    failed_dims, errs = find_errs(turbine_output[idx], torch_output[idx], dim, failed_dims, errs)
+                    failed_dims, errs = find_errs(
+                        turbine_output[idx], torch_output[idx], dim, failed_dims, errs
+                    )
     return (failed_dims, errs)
 
+
 if __name__ == "__main__":
     from turbine_models.custom_models.sd3_inference.sd3_cmd_opts import args
     import numpy as np
@@ -137,8 +145,8 @@ def find_errs(turbine_output, torch_output, dim=[], failed_dims=[], errs=[]):
     print("Torch output: ", errs[idx][2])
     print(torch_output.shape)
     exit()
-
-    batch_size = args.batch_size * 2 #do classifier free guidance
+
+    batch_size = args.batch_size * 2  # do classifier free guidance
     hidden_states = torch.randn(
         (batch_size, 16, args.height // 8, args.width // 8), dtype=dtype
     )
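The find_errs function reformatted above recursively walks mismatched axes of the two arrays, recording every index path whose subtree fails the 4e-2 tolerance along with the absolute errors. A usage sketch with invented data, written as if it ran inside this script's __main__ block, where np and find_errs are both in scope:

import numpy as np  # in the script this import lives in the __main__ block

turbine_out = np.zeros((2, 3, 4), dtype=np.float32)
torch_out = turbine_out.copy()
torch_out[1, 2, 0] = 1.0  # perturb one element beyond rtol=atol=4e-2

failed_dims, errs = find_errs(turbine_out, torch_out, dim=[], failed_dims=[], errs=[])
print(failed_dims)  # index paths into failing subtrees: [[1], [1, 2], [1, 2, 0]]

Passing fresh lists explicitly matters here: the dim=[], failed_dims=[], errs=[] defaults are mutable, so results would otherwise accumulate across repeated calls to the function.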

models/turbine_models/custom_models/sd3_inference/sd3_pipeline.py

Lines changed: 25 additions & 12 deletions
@@ -75,38 +75,42 @@ def __init__(
         self.num_inference_steps = num_inference_steps
         self.devices = {}
         if isinstance(device, dict):
-            assert isinstance(iree_target_triple, dict), "Device and target triple must be both dicts or both strings."
+            assert isinstance(
+                iree_target_triple, dict
+            ), "Device and target triple must be both dicts or both strings."
             self.devices["clip"] = {
                 "device": device["clip"],
                 "driver": utils.iree_device_map(device["clip"]),
-                "target": iree_target_triple["clip"]
+                "target": iree_target_triple["clip"],
             }
             self.devices["mmdit"] = {
                 "device": device["mmdit"],
                 "driver": utils.iree_device_map(device["mmdit"]),
-                "target": iree_target_triple["mmdit"]
+                "target": iree_target_triple["mmdit"],
             }
             self.devices["vae"] = {
                 "device": device["vae"],
                 "driver": utils.iree_device_map(device["vae"]),
-                "target": iree_target_triple["vae"]
+                "target": iree_target_triple["vae"],
             }
         else:
-            assert isinstance(iree_target_triple, str), "Device and target triple must be both dicts or both strings."
+            assert isinstance(
+                iree_target_triple, str
+            ), "Device and target triple must be both dicts or both strings."
             self.devices["clip"] = {
                 "device": device,
                 "driver": utils.iree_device_map(device),
-                "target": iree_target_triple
+                "target": iree_target_triple,
             }
             self.devices["mmdit"] = {
                 "device": device,
                 "driver": utils.iree_device_map(device),
-                "target": iree_target_triple
+                "target": iree_target_triple,
             }
             self.devices["vae"] = {
                 "device": device,
                 "driver": utils.iree_device_map(device),
-                "target": iree_target_triple
+                "target": iree_target_triple,
             }
         self.iree_target_triple = iree_target_triple
         self.ireec_flags = ireec_flags if ireec_flags else EMPTY_FLAGS
@@ -645,7 +649,8 @@ def generate_images(
             image.save(img_path)
             print(img_path, "saved")
         return
-
+
+
 def run_diffusers_cpu(
     hf_model_name,
     prompt,
@@ -658,7 +663,9 @@ def run_diffusers_cpu(
 ):
     from diffusers import StableDiffusion3Pipeline
 
-    pipe = StableDiffusion3Pipeline.from_pretrained(hf_model_name, torch_dtype=torch.float32)
+    pipe = StableDiffusion3Pipeline.from_pretrained(
+        hf_model_name, torch_dtype=torch.float32
+    )
     pipe = pipe.to("cpu")
     generator = torch.Generator().manual_seed(int(seed))
 
@@ -703,7 +710,9 @@ def run_diffusers_cpu(
            x for x in [args.clip_target, args.mmdit_target, args.vae_target]
        ), "Please specify target triple for all submodels or pass --iree_target_triple for all submodels."
        args.device = "hybrid"
-        args.iree_target_triple = "_".join([args.clip_target, args.mmdit_target, args.vae_target])
+        args.iree_target_triple = "_".join(
+            [args.clip_target, args.mmdit_target, args.vae_target]
+        )
     else:
         args.clip_device = args.device
         args.mmdit_device = args.device
@@ -785,7 +794,11 @@ def run_diffusers_cpu(
     else:
         extra_device_args = {}
     sd3_pipe.load_pipeline(
-        vmfbs, weights, args.compiled_pipeline, args.split_scheduler, extra_device_args=extra_device_args
+        vmfbs,
+        weights,
+        args.compiled_pipeline,
+        args.split_scheduler,
+        extra_device_args=extra_device_args,
    )
    sd3_pipe.generate_images(
        args.prompt,
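Both pipeline files in this commit build the same per-submodel device map from either a single device string or a dict keyed by submodel. A hypothetical helper restating the __init__ logic above with the three duplicated blocks folded into a loop; build_device_map is not part of the repo, only an illustration of the pattern:

from turbine_models.custom_models.sd_inference.utils import iree_device_map


def build_device_map(device, iree_target_triple, submodels=("clip", "mmdit", "vae")):
    # Either both arguments are dicts keyed by submodel, or both are strings.
    if isinstance(device, dict):
        assert isinstance(
            iree_target_triple, dict
        ), "Device and target triple must be both dicts or both strings."
        get = lambda spec, name: spec[name]
    else:
        assert isinstance(
            iree_target_triple, str
        ), "Device and target triple must be both dicts or both strings."
        get = lambda spec, name: spec
    devices = {}
    for name in submodels:
        devices[name] = {
            "device": get(device, name),
            "driver": iree_device_map(get(device, name)),
            "target": get(iree_target_triple, name),
        }
    return devices

With a plain string pair such as build_device_map("rocm://0", "gfx942"), all three submodels get identical entries, matching the else branch of the hunk.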

models/turbine_models/custom_models/sd3_inference/sd3_schedulers.py

Lines changed: 2 additions & 1 deletion
@@ -66,7 +66,7 @@ def __init__(
     def initialize(self, sample):
         step_count = torch.tensor(len(self.timesteps))
         timesteps = self.model.timesteps
-        #ops.trace_tensor("sample", sample[:,:,0,0])
+        # ops.trace_tensor("sample", sample[:,:,0,0])
         return (
             sample,
             step_count,
@@ -93,6 +93,7 @@ def step(self, noise_pred, t, sample, guidance_scale, i):
         sample = self.model.step(noise_pred, t, sample, return_dict=False)[0]
         return sample.type(self.dtype)
 
+
 # Wraps a diffusers scheduler running on native pytorch+cpu.
 # This allows us to use it interchangeably with compiled schedulers in our pipeline(s).
 class TorchCPUFlowSchedulerCompat:
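The comment in the last hunk describes an adapter: a CPU diffusers scheduler exposed behind the same stepping surface as a compiled scheduler. A minimal sketch of that pattern, using only the step behavior visible in the hunk above; the constructor arguments and any other methods of the real TorchCPUFlowSchedulerCompat are not shown in this diff:

import torch


class SchedulerCompatSketch:
    # Hypothetical adapter: wraps a diffusers scheduler (self.model) so the
    # pipeline can call it like a compiled scheduler module.
    def __init__(self, scheduler, dtype=torch.float32):
        self.model = scheduler
        self.dtype = dtype

    def step(self, noise_pred, t, sample):
        # Mirrors the step body shown above: delegate to diffusers, then cast
        # back to the pipeline dtype.
        sample = self.model.step(noise_pred, t, sample, return_dict=False)[0]
        return sample.type(self.dtype)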

models/turbine_models/custom_models/sd_inference/utils.py

Lines changed: 2 additions & 0 deletions
@@ -92,6 +92,7 @@
     ],
 }
 
+
 def iree_device_map(device):
     uri_parts = device.split("://", 2)
     iree_driver = (
@@ -106,6 +107,7 @@ def iree_device_map(device):
     else:
         return f"{iree_driver}://{uri_parts[1]}"
 
+
 def compile_to_vmfb(
     module_str,
     device,
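Based on the two hunks above, iree_device_map splits an optional "://<index>" suffix off the device string, translates the bare name through a driver table defined earlier in utils.py, and reattaches the index if present. A usage sketch; the concrete driver names depend on that table, so the commented outputs are illustrative assumptions:

from turbine_models.custom_models.sd_inference.utils import iree_device_map

print(iree_device_map("rocm"))      # bare device name -> mapped driver, e.g. "rocm"
print(iree_device_map("rocm://1"))  # device index is preserved, e.g. "rocm://1"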

models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py

Lines changed: 35 additions & 14 deletions
@@ -78,38 +78,42 @@ def __init__(
         self.num_inference_steps = num_inference_steps
         self.devices = {}
         if isinstance(device, dict):
-            assert isinstance(iree_target_triple, dict), "Device and target triple must be both dicts or both strings."
+            assert isinstance(
+                iree_target_triple, dict
+            ), "Device and target triple must be both dicts or both strings."
             self.devices["clip"] = {
                 "device": device["clip"],
                 "driver": utils.iree_device_map(device["clip"]),
-                "target": iree_target_triple["clip"]
+                "target": iree_target_triple["clip"],
             }
             self.devices["unet"] = {
                 "device": device["unet"],
                 "driver": utils.iree_device_map(device["unet"]),
-                "target": iree_target_triple["unet"]
+                "target": iree_target_triple["unet"],
             }
             self.devices["vae"] = {
                 "device": device["vae"],
                 "driver": utils.iree_device_map(device["vae"]),
-                "target": iree_target_triple["vae"]
+                "target": iree_target_triple["vae"],
             }
         else:
-            assert isinstance(iree_target_triple, str), "Device and target triple must be both dicts or both strings."
+            assert isinstance(
+                iree_target_triple, str
+            ), "Device and target triple must be both dicts or both strings."
             self.devices["clip"] = {
                 "device": device,
                 "driver": utils.iree_device_map(device),
-                "target": iree_target_triple
+                "target": iree_target_triple,
             }
             self.devices["unet"] = {
                 "device": device,
                 "driver": utils.iree_device_map(device),
-                "target": iree_target_triple
+                "target": iree_target_triple,
             }
             self.devices["vae"] = {
                 "device": device,
                 "driver": utils.iree_device_map(device),
-                "target": iree_target_triple
+                "target": iree_target_triple,
             }
         self.ireec_flags = ireec_flags if ireec_flags else EMPTY_FLAGS
         self.attn_spec = attn_spec
@@ -548,8 +552,14 @@ def load_pipeline(
             clip_loaded = time.time()
             print("\n[LOG] CLIP loaded in ", clip_loaded - vae_loaded, "sec")
         elif compiled_pipeline:
-            assert self.devices["unet"]["device"] == self.devices["clip"]["device"] == self.devices["vae"]["device"], "Compiled pipeline requires all submodels to be on the same device."
-            assert self.precision == self.vae_precision, "Compiled pipeline requires all submodels to have the same precision for now."
+            assert (
+                self.devices["unet"]["device"]
+                == self.devices["clip"]["device"]
+                == self.devices["vae"]["device"]
+            ), "Compiled pipeline requires all submodels to be on the same device."
+            assert (
+                self.precision == self.vae_precision
+            ), "Compiled pipeline requires all submodels to have the same precision for now."
             runners["pipe"] = vmfbRunner(
                 self.devices["unet"]["driver"],
                 [
@@ -796,9 +806,14 @@ def generate_images(
             latents = self.runners["pipe"].ctx.modules.sdxl_compiled_pipeline[
                 "produce_image_latents"
             ](samples[i], prompt_embeds, add_text_embeds, guidance_scale)
-            if self.devices["unet"]["driver"] != self.devices["vae"]["driver"] or self.precision != self.vae_precision:
+            if (
+                self.devices["unet"]["driver"] != self.devices["vae"]["driver"]
+                or self.precision != self.vae_precision
+            ):
                 latents = ireert.asdevicearray(
-                    self.runners["vae_decode"].config.device, latents.to_host(), dtype=self.vae_dtype
+                    self.runners["vae_decode"].config.device,
+                    latents.to_host(),
+                    dtype=self.vae_dtype,
                 )
             vae_start = time.time()
             vae_out = self.runners["vae_decode"].ctx.modules.compiled_vae["main"](
@@ -906,7 +921,9 @@ def numpy_to_pil_image(images):
            x for x in [args.clip_target, args.unet_target, args.vae_target]
        ), "Please specify target triple for all submodels or pass --iree_target_triple for all submodels."
        args.device = "hybrid"
-        args.iree_target_triple = "_".join([args.clip_target, args.unet_target, args.vae_target])
+        args.iree_target_triple = "_".join(
+            [args.clip_target, args.unet_target, args.vae_target]
+        )
     else:
         args.clip_device = args.device
         args.unet_device = args.device
@@ -987,7 +1004,11 @@ def numpy_to_pil_image(images):
     else:
         extra_device_args = {}
     sdxl_pipe.load_pipeline(
-        vmfbs, weights, args.compiled_pipeline, args.split_scheduler, extra_device_args,
+        vmfbs,
+        weights,
+        args.compiled_pipeline,
+        args.split_scheduler,
+        extra_device_args,
    )
    sdxl_pipe.generate_images(
        args.prompt,
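The generate_images hunk moves latents across submodels whenever the unet and vae use different drivers or precisions: pull the array to host, then re-wrap it on the vae's device with the vae dtype. A standalone sketch of that handoff using the iree.runtime API; the "local-task" driver and the shapes are example choices, not taken from the diff:

import numpy as np
import iree.runtime as ireert

unet_config = ireert.Config("local-task")
vae_config = ireert.Config("local-task")

# Stand-in for latents produced by the unet on its device.
latents = ireert.asdevicearray(
    unet_config.device, np.zeros((1, 4, 128, 128), dtype=np.float32)
)

# Cross-device handoff, casting to the vae dtype as in the hunk above.
latents = ireert.asdevicearray(
    vae_config.device, latents.to_host(), dtype=np.float16
)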

models/turbine_models/custom_models/sdxl_inference/vae.py

Lines changed: 2 additions & 2 deletions
@@ -107,7 +107,7 @@ def export_vae_model(
 
     if device == "cpu":
         decomp_attn = True
-
+
     dtype = torch.float16 if precision == "fp16" else torch.float32
     if precision == "fp16":
         vae_model = vae_model.half()
@@ -119,7 +119,7 @@ def export_vae_model(
     )
     if weights_only:
         return external_weight_path
-
+
     input_image_shape = (height, width, 3)
     input_latents_shape = (batch_size, 4, height // 8, width // 8)
     encode_args = [
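For reference, the export shapes visible in the second hunk pair an HWC image input with 8x-downsampled 4-channel latents. A small illustration with example values; the decode_args name and the fp16 choice are assumptions for the sketch:

import torch

batch_size, height, width = 1, 1024, 1024
input_image_shape = (height, width, 3)  # HWC image fed to encode
input_latents_shape = (batch_size, 4, height // 8, width // 8)  # VAE latents

encode_args = [torch.empty(input_image_shape, dtype=torch.float16)]
decode_args = [torch.empty(input_latents_shape, dtype=torch.float16)]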
