
Commit 6045016

Start fixing tests
1 parent abc3e6e · commit 6045016

6 files changed: +161 −193 lines

models/turbine_models/custom_models/sd_inference/schedulers.py

Lines changed: 9 additions & 9 deletions
@@ -41,7 +41,9 @@ def __init__(self, rt_device, vmfb):
         self.runner = vmfbRunner(rt_device, vmfb, None)
 
     def initialize(self, sample):
-        sample, time_ids, steps, timesteps = self.runner.ctx.modules.compiled_scheduler["run_initialize"](sample)
+        sample, time_ids, steps, timesteps = self.runner.ctx.modules.compiled_scheduler[
+            "run_initialize"
+        ](sample)
         return sample, time_ids, steps.to_host(), timesteps
 
     def scale_model_input(self, sample, t, timesteps):
@@ -50,11 +52,6 @@ def scale_model_input(self, sample, t, timesteps):
         )
 
     def step(self, noise_pred, t, sample, guidance_scale, step_index):
-        print(
-            noise_pred.to_host()[:,:,0,2],
-            t,
-            sample.to_host()[:,:,0,2],
-        )
         return self.runner.ctx.modules.compiled_scheduler["run_step"](
             noise_pred, t, sample, guidance_scale, step_index
         )
@@ -128,6 +125,7 @@ def step(self, noise_pred, t, sample, guidance_scale, i):
         sample = self.model.step(noise_pred, t, sample, return_dict=False)[0]
         return sample.type(self.dtype)
 
+
 class SharkSchedulerCPUWrapper:
     @torch.no_grad()
     def __init__(
@@ -183,11 +181,13 @@ def step(self, noise_pred, t, latents, guidance_scale, i):
         noise_pred = torch.tensor(noise_pred.to_host())
         if self.do_classifier_free_guidance:
             noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+            noise_pred = noise_pred_uncond + guidance_scale * (
+                noise_pred_text - noise_pred_uncond
+            )
         print(
-            noise_pred[:,:,0,2],
+            noise_pred[:, :, 0, 2],
             t,
-            latents[:,:,0,2],
+            latents[:, :, 0, 2],
         )
         return self.module.step(
             noise_pred,
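The guidance line reformatted in the last hunk is the standard classifier-free guidance blend: the batched prediction is split into its unconditional and text-conditioned halves and recombined. A minimal, self-contained sketch of just that arithmetic (tensor shapes and the helper name are illustrative, not taken from the repo):

    import torch

    def classifier_free_guidance(noise_pred: torch.Tensor, guidance_scale: float) -> torch.Tensor:
        # The batch stacks the unconditional and text-conditioned predictions,
        # so chunk(2) splits it along the batch dimension before recombining.
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        return noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

    # Example: a dummy batch of two 4x128x128 latent predictions, guidance scale 7.5.
    guided = classifier_free_guidance(torch.randn(2, 4, 128, 128), 7.5)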

models/turbine_models/custom_models/sdxl_inference/sdxl_compiled_pipeline.py

Lines changed: 42 additions & 15 deletions
@@ -132,7 +132,9 @@ def check_prepared(
                     vmfbs[submodel] = vmfb
                     if weights[submodel] is None:
                         weights[submodel] = weight
-                elif weights[submodel] is None and not any(x in submodel for x in ["pipeline", "scheduler"]):
+                elif weights[submodel] is None and not any(
+                    x in submodel for x in ["pipeline", "scheduler"]
+                ):
                     _, weight = self.export_submodel(submodel, weights_only=True)
                     weights[submodel] = weight
             ready, vmfbs, weights = self.is_prepared(vmfbs, weights)
@@ -157,7 +159,7 @@ def is_prepared(self, vmfbs, weights):
                 default_filepath = os.path.join(self.pipeline_dir, val + ".vmfb")
             elif key == "scheduler":
                 val = None
-                default_filepath=None
+                default_filepath = None
                 continue
             else:
                 val = vmfbs[key]
@@ -494,7 +496,9 @@ def load_pipeline(
                 )
             else:
                 print("\n[LOG] Running scheduler on CPU. This will affect performance.")
-                scheduler = schedulers.get_scheduler(args.hf_model_name, args.scheduler_id)
+                scheduler = schedulers.get_scheduler(
+                    args.hf_model_name, args.scheduler_id
+                )
                 runners["scheduler"] = schedulers.SharkSchedulerCPUWrapper(
                     scheduler,
                     args.batch_size,
@@ -535,7 +539,9 @@ def load_pipeline(
                 ],
             )
             pipe_loaded = time.time()
-            print("\n[LOG] Compiled Pipeline loaded in ", pipe_loaded - load_start, "sec")
+            print(
+                "\n[LOG] Compiled Pipeline loaded in ", pipe_loaded - load_start, "sec"
+            )
 
         else:
             runners["pipe"] = vmfbRunner(
@@ -556,7 +562,9 @@ def load_pipeline(
             runners["vae_decode"] = runners["pipe"]
             runners["prompt_encoder"] = runners["pipe"]
             pipe_loaded = time.time()
-            print("\n[LOG] Compiled Pipeline loaded in ", pipe_loaded - load_start, "sec")
+            print(
+                "\n[LOG] Compiled Pipeline loaded in ", pipe_loaded - load_start, "sec"
+            )
         tok_start = time.time()
         runners["tokenizer_1"] = CLIPTokenizer.from_pretrained(
             self.hf_model_name,
@@ -704,11 +712,17 @@ def generate_images(
         for i in range(batch_count):
             unet_start = time.time()
             if self.runners["scheduler"]:
-                sample, time_ids, steps, timesteps = self.runners["scheduler"].initialize(samples[i])
+                sample, time_ids, steps, timesteps = self.runners[
+                    "scheduler"
+                ].initialize(samples[i])
                 iree_inputs = [
                     sample,
-                    ireert.asdevicearray(self.runners["pipe"].config.device, prompt_embeds),
-                    ireert.asdevicearray(self.runners["pipe"].config.device, add_text_embeds),
+                    ireert.asdevicearray(
+                        self.runners["pipe"].config.device, prompt_embeds
+                    ),
+                    ireert.asdevicearray(
+                        self.runners["pipe"].config.device, add_text_embeds
+                    ),
                     time_ids,
                     None,
                 ]
@@ -717,13 +731,19 @@ def generate_images(
                     if self.cpu_scheduling:
                         step_index = s
                     else:
-                        step_index = ireert.asdevicearray(self.runners["scheduler"].runner.config.device, torch.tensor([s]), "int64")
+                        step_index = ireert.asdevicearray(
+                            self.runners["scheduler"].runner.config.device,
+                            torch.tensor([s]),
+                            "int64",
+                        )
                     latents, t = self.runners["scheduler"].scale_model_input(
                         sample,
                         step_index,
                         timesteps,
                     )
-                    noise_pred = self.runners["pipe"].ctx.modules.compiled_unet["run_forward"](
+                    noise_pred = self.runners["pipe"].ctx.modules.compiled_unet[
+                        "run_forward"
+                    ](
                         latents,
                         t,
                         iree_inputs[1],
@@ -738,9 +758,13 @@ def generate_images(
                         step_index,
                     )
                 if isinstance(sample, torch.Tensor):
-                    #TODO: pipe an option for vae_dtype
+                    # TODO: pipe an option for vae_dtype
                     vae_dtype = "float32" if self.precision == "fp32" else "float16"
-                    latents = ireert.asdevicearray(self.runners["vae_decode"].config.device, sample, dtype=vae_dtype)
+                    latents = ireert.asdevicearray(
+                        self.runners["vae_decode"].config.device,
+                        sample,
+                        dtype=vae_dtype,
+                    )
                 else:
                     latents = sample
             else:
@@ -833,6 +857,7 @@ def numpy_to_pil_image(images):
 
 if __name__ == "__main__":
     from turbine_models.custom_models.sdxl_inference.sdxl_cmd_opts import args
+
     map = empty_pipe_dict
     if args.split_scheduler:
         map["scheduler"] = None
@@ -894,13 +919,15 @@ def numpy_to_pil_image(images):
         args.external_weights_dir,
         args.external_weights,
         args.vae_decomp_attn,
-        custom_vae = None,
-        cpu_scheduling = args.cpu_scheduling,
+        custom_vae=None,
+        cpu_scheduling=args.cpu_scheduling,
     )
     vmfbs, weights = sdxl_pipe.check_prepared(mlirs, vmfbs, weights)
     if args.cpu_scheduling:
         vmfbs["scheduler"] = None
-    sdxl_pipe.load_pipeline(vmfbs, weights, args.rt_device, args.compiled_pipeline, args.split_scheduler)
+    sdxl_pipe.load_pipeline(
+        vmfbs, weights, args.rt_device, args.compiled_pipeline, args.split_scheduler
+    )
     sdxl_pipe.generate_images(
         args.prompt,
         args.negative_prompt,
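For orientation, the generate_images hunks above all reformat calls inside the same split-scheduler denoising loop. A condensed sketch of that control flow, with names taken from the diff but argument lists partly elided and the loop bound, function name, and parameters assumed for illustration:

    import torch
    import iree.runtime as ireert

    def denoise(runners, sample, timesteps, num_steps, guidance_scale, cpu_scheduling, unet_extra_inputs):
        # Condensed sketch only: generate_images also prepares prompt embeddings,
        # time ids, and VAE handoff, which are elided here.
        for s in range(num_steps):
            if cpu_scheduling:
                step_index = s
            else:
                # On-device scheduling expects the step index as an int64 device array.
                step_index = ireert.asdevicearray(
                    runners["scheduler"].runner.config.device, torch.tensor([s]), "int64"
                )
            latents, t = runners["scheduler"].scale_model_input(sample, step_index, timesteps)
            noise_pred = runners["pipe"].ctx.modules.compiled_unet["run_forward"](
                latents, t, *unet_extra_inputs  # prompt embeds, text embeds, time ids, ...
            )
            sample = runners["scheduler"].step(
                noise_pred, t, sample, guidance_scale, step_index
            )
        return sample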

models/turbine_models/custom_models/sdxl_inference/sdxl_prompt_encoder_runner.py

Lines changed: 47 additions & 66 deletions
@@ -5,58 +5,18 @@
 import numpy as np
 
 
-def run_torch_clip(hf_model_name, hf_auth_token, prompt, max_length=64):
-    # TODO: Integrate with HFTransformerBuilder
-    from turbine_models.custom_models.sdxl_inference.clip import ClipModel
-
-    model_1 = ClipModel(hf_model_name, hf_auth_token, index=1)
-    model_2 = ClipModel(hf_model_name, hf_auth_token, index=2)
-    tokenizer_1 = CLIPTokenizer.from_pretrained(
-        hf_model_name,
-        subfolder="tokenizer",
-        token=hf_auth_token,
-    )
-    tokenizer_2 = CLIPTokenizer.from_pretrained(
-        hf_model_name,
-        subfolder="tokenizer_2",
-        token=hf_auth_token,
-    )
-    text_input_1 = tokenizer_1(
-        prompt,
-        padding="max_length",
-        max_length=max_length,
-        truncation=True,
-        return_tensors="pt",
-    )
-    text_input_2 = tokenizer_2(
-        prompt,
-        padding="max_length",
-        max_length=max_length,
-        truncation=True,
-        return_tensors="pt",
-    )
-    example_input_1 = text_input_1.input_ids
-    example_input_2 = text_input_2.input_ids
-
-    results_1 = model_1.forward(example_input_1)
-    results_2 = model_2.forward(example_input_2)
-    np_torch_output_1 = results_1[0].detach().cpu().numpy().astype(np.float16)
-    np_torch_output_2 = results_2[0].detach().cpu().numpy().astype(np.float16)
-    return np_torch_output_1, np_torch_output_2
-
-
 def run_prompt_encoder(
-    args,
+    vmfb_path,
+    device,
+    external_weight_path,
     input_ids,
     uncond_input_ids,
 ):
-    prompt_encoder_runner = vmfbRunner(
-        args.device, args.vmfb_path, args.external_weight_path
-    )
-    np.save("input0.npy", input_ids[0].numpy())
-    np.save("input1.npy", input_ids[1].numpy())
-    np.save("input2.npy", uncond_input_ids[0].numpy())
-    np.save("input3.npy", uncond_input_ids[1].numpy())
+    prompt_encoder_runner = vmfbRunner(device, vmfb_path, external_weight_path)
+    # np.save("input0.npy", input_ids[0].numpy())
+    # np.save("input1.npy", input_ids[1].numpy())
+    # np.save("input2.npy", uncond_input_ids[0].numpy())
+    # np.save("input3.npy", uncond_input_ids[1].numpy())
     prompt_encoder_inputs = [
         ireert.asdevicearray(prompt_encoder_runner.config.device, input_ids[0]),
         ireert.asdevicearray(prompt_encoder_runner.config.device, input_ids[1]),
@@ -66,40 +26,36 @@ def run_prompt_encoder(
     encoded_outputs = prompt_encoder_runner.ctx.modules.compiled_clip["encode_prompts"](
         *prompt_encoder_inputs
    )
+    for i in encoded_outputs:
+        i = i.to_host()
     del prompt_encoder_inputs
     return encoded_outputs
 
 
-if __name__ == "__main__":
-    from turbine_models.custom_models.sdxl_inference.sdxl_cmd_opts import args
-
-    tokenizer_1 = CLIPTokenizer.from_pretrained(
-        args.hf_model_name,
-        subfolder="tokenizer",
-        token=args.hf_auth_token,
-    )
-    tokenizer_2 = CLIPTokenizer.from_pretrained(
-        args.hf_model_name,
-        subfolder="tokenizer_2",
-        token=args.hf_auth_token,
-    )
+def run_tokenize(
+    tokenizer_1,
+    tokenizer_2,
+    prompt,
+    negative_prompt,
+    max_length=64,
+):
     text_input_ids_list = []
     uncond_input_ids_list = []
 
     # Tokenize prompt and negative prompt.
     tokenizers = [tokenizer_1, tokenizer_2]
     for tokenizer in tokenizers:
         text_inputs = tokenizer(
-            args.prompt,
+            prompt,
             padding="max_length",
-            max_length=args.max_length,
+            max_length=max_length,
             truncation=True,
             return_tensors="pt",
         )
         uncond_input = tokenizer(
-            args.negative_prompt,
+            negative_prompt,
             padding="max_length",
-            max_length=args.max_length,
+            max_length=max_length,
             truncation=True,
             return_tensors="pt",
         )
@@ -108,9 +64,34 @@ def run_prompt_encoder(
 
         text_input_ids_list.extend([text_input_ids])
         uncond_input_ids_list.extend([uncond_input_ids])
+    return text_input_ids_list, uncond_input_ids_list
+
 
+if __name__ == "__main__":
+    from turbine_models.custom_models.sdxl_inference.sdxl_cmd_opts import args
+
+    tokenizer_1 = CLIPTokenizer.from_pretrained(
+        args.hf_model_name,
+        subfolder="tokenizer",
+        token=args.hf_auth_token,
+    )
+    tokenizer_2 = CLIPTokenizer.from_pretrained(
+        args.hf_model_name,
+        subfolder="tokenizer_2",
+        token=args.hf_auth_token,
+    )
+
+    text_input_ids_list, uncond_input_ids_list = run_tokenize(
+        tokenizer_1,
+        tokenizer_2,
+        args.prompt,
+        args.negative_prompt,
+        args.max_length,
+    )
     turbine_output1, turbine_output2 = run_prompt_encoder(
-        args,
+        args.vmfb_path,
+        args.rt_device,
+        args.external_weight_path,
         text_input_ids_list,
        uncond_input_ids_list,
     )
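With this refactor, run_tokenize and run_prompt_encoder no longer read the global args namespace, so a test can drive them directly. A hedged usage sketch, in which the model name, vmfb path, device string, and weights path are placeholders rather than values from the repo:

    from transformers import CLIPTokenizer
    from turbine_models.custom_models.sdxl_inference.sdxl_prompt_encoder_runner import (
        run_prompt_encoder,
        run_tokenize,
    )

    # Placeholder values for illustration; substitute real artifacts.
    hf_model_name = "stabilityai/stable-diffusion-xl-base-1.0"
    tokenizer_1 = CLIPTokenizer.from_pretrained(hf_model_name, subfolder="tokenizer")
    tokenizer_2 = CLIPTokenizer.from_pretrained(hf_model_name, subfolder="tokenizer_2")

    text_ids, uncond_ids = run_tokenize(
        tokenizer_1, tokenizer_2, "a photo of a cat", "", max_length=64
    )
    encoded = run_prompt_encoder(
        "prompt_encoder.vmfb",          # assumed vmfb path
        "local-task",                   # assumed IREE runtime device
        "prompt_encoder.safetensors",   # assumed external weights path
        text_ids,
        uncond_ids,
    )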

models/turbine_models/custom_models/sdxl_inference/unet.py

Lines changed: 1 addition & 3 deletions
@@ -94,9 +94,7 @@ def export_unet_model(
     weights_only=False,
 ):
     if pipeline_dir:
-        safe_name = os.path.join(
-            pipeline_dir, f"unet"
-        )
+        safe_name = os.path.join(pipeline_dir, f"unet")
     else:
         safe_name = utils.create_safe_name(
             hf_model_name,
