Commit abc3e6e

Fixes for cpu schedulers, add split scheduler support to sdxl pipeline
1 parent: 9eea857

6 files changed: +271 -84 lines changed

models/turbine_models/custom_models/sd_inference/schedulers.py

Lines changed: 47 additions & 35 deletions
@@ -41,14 +41,20 @@ def __init__(self, rt_device, vmfb):
         self.runner = vmfbRunner(rt_device, vmfb, None)
 
     def initialize(self, sample):
-        return self.runner.ctx.modules.compiled_scheduler["run_initialize"](sample)
+        sample, time_ids, steps, timesteps = self.runner.ctx.modules.compiled_scheduler["run_initialize"](sample)
+        return sample, time_ids, steps.to_host(), timesteps
 
     def scale_model_input(self, sample, t, timesteps):
         return self.runner.ctx.modules.compiled_scheduler["run_scale"](
             sample, t, timesteps
         )
 
     def step(self, noise_pred, t, sample, guidance_scale, step_index):
+        print(
+            noise_pred.to_host()[:,:,0,2],
+            t,
+            sample.to_host()[:,:,0,2],
+        )
         return self.runner.ctx.modules.compiled_scheduler["run_step"](
             noise_pred, t, sample, guidance_scale, step_index
         )
@@ -98,7 +104,7 @@ def initialize(self, sample):
             sample.type(self.dtype),
             add_time_ids,
             step_count,
-            timesteps.type(self.dtype),
+            timesteps.type(torch.float32),
         )
 
     def prepare_model_input(self, sample, t, timesteps):
@@ -119,15 +125,11 @@ def step(self, noise_pred, t, sample, guidance_scale, i):
             noise_pred = noise_preds[0] + guidance_scale * (
                 noise_preds[1] - noise_preds[0]
             )
-        if self.model.config.skip_prk_steps == True:
-            sample = self.model.step_plms(noise_pred, t, sample, return_dict=False)[0]
-        else:
-            sample = self.model.step(noise_pred, t, sample, return_dict=False)[0]
+        sample = self.model.step(noise_pred, t, sample, return_dict=False)[0]
         return sample.type(self.dtype)
 
-
-@torch.no_grad()
 class SharkSchedulerCPUWrapper:
+    @torch.no_grad()
     def __init__(
         self, scheduler, batch_size, num_inference_steps, dest_device, latents_dtype
     ):
@@ -137,13 +139,16 @@ def __init__(
         self.dtype = latents_dtype
         self.batch_size = batch_size
         self.module.set_timesteps(num_inference_steps)
+        self.timesteps = self.module.timesteps
         self.torch_dtype = (
             torch.float32 if latents_dtype == "float32" else torch.float16
         )
 
     def initialize(self, sample):
-        height = sample.shape[2]
-        width = sample.shape[3]
+        if isinstance(sample, ireert.DeviceArray):
+            sample = torch.tensor(sample.to_host(), dtype=torch.float32)
+        height = sample.shape[2] * 8
+        width = sample.shape[3] * 8
         original_size = (height, width)
         target_size = (height, width)
         crops_coords_top_left = (0, 0)
@@ -155,10 +160,10 @@ def initialize(self, sample):
             self.torch_dtype
         )
         step_indexes = torch.tensor(len(self.module.timesteps))
-        timesteps = self.module.timesteps
+        timesteps = self.timesteps
         sample = sample * self.module.init_noise_sigma
+        print(sample, add_time_ids, step_indexes, timesteps)
         add_time_ids = ireert.asdevicearray(self.dest, add_time_ids, self.dtype)
-        step_indexes = ireert.asdevicearray(self.dest, step_indexes, "int64")
         return sample, add_time_ids, step_indexes, timesteps
 
     def scale_model_input(self, sample, t, timesteps):
@@ -167,24 +172,27 @@ def scale_model_input(self, sample, t, timesteps):
         t = timesteps[t]
         scaled = self.module.scale_model_input(sample, t)
         t = ireert.asdevicearray(self.dest, [t], self.dtype)
+        scaled = ireert.asdevicearray(self.dest, scaled, self.dtype)
         return scaled, t
 
-    def step(self, latents, t, sample, guidance_scale, i):
-        if isinstance(latents, ireert.DeviceArray):
-            latents = torch.tensor(latents.to_host())
+    def step(self, noise_pred, t, latents, guidance_scale, i):
         if isinstance(t, ireert.DeviceArray):
-            t = self.module.timesteps[i]
-        if isinstance(sample, ireert.DeviceArray):
-            sample = torch.tensor(sample.to_host())
+            t = torch.tensor(t.to_host())
+        if isinstance(guidance_scale, ireert.DeviceArray):
+            guidance_scale = torch.tensor(guidance_scale.to_host())
+        noise_pred = torch.tensor(noise_pred.to_host())
         if self.do_classifier_free_guidance:
-            noise_preds = latents.chunk(2)
-            latents = noise_preds[0] + guidance_scale * (
-                noise_preds[1] - noise_preds[0]
-            )
+            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+        print(
+            noise_pred[:,:,0,2],
+            t,
+            latents[:,:,0,2],
+        )
         return self.module.step(
-            latents,
+            noise_pred,
             t,
-            sample,
+            latents,
             return_dict=False,
         )[0]
 
@@ -212,19 +220,23 @@ def export_scheduler_model(
     scheduler_module = SchedulingModel(
         hf_model_name, scheduler, height, width, batch_size, num_inference_steps, dtype
     )
-    vmfb_names = [
-        scheduler_id + "Scheduler",
-        f"bs{batch_size}",
-        f"{height}x{width}",
-        precision,
-        str(num_inference_steps),
-        target_triple,
-    ]
-    vmfb_name = "_".join(vmfb_names)
-
     if pipeline_dir:
+        vmfb_names = [
+            scheduler_id + "Scheduler",
+            str(num_inference_steps),
+        ]
+        vmfb_name = "_".join(vmfb_names)
         safe_name = os.path.join(pipeline_dir, vmfb_name)
     else:
+        vmfb_names = [
+            scheduler_id + "Scheduler",
+            f"bs{batch_size}",
+            f"{height}x{width}",
+            precision,
+            str(num_inference_steps),
+            target_triple,
+        ]
+        vmfb_name = "_".join(vmfb_names)
         safe_name = utils.create_safe_name(hf_model_name, "_" + vmfb_name)
 
     if input_mlir:
@@ -261,7 +273,7 @@ def export_scheduler_model(
     example_prep_args = (
         torch.empty(sample, dtype=dtype),
         torch.empty(1, dtype=torch.int64),
-        torch.empty([19], dtype=dtype),
+        torch.empty([19], dtype=torch.float32),
     )
     timesteps = torch.export.Dim("timesteps")
     prep_dynamic_args = {
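
To show how the pieces fit together after this change, here is a minimal sketch of a denoising loop driving SharkSchedulerCPUWrapper. The initialize / scale_model_input / step signatures follow the diff above; the unet callable and the loop scaffolding are illustrative assumptions, not code from this commit.

```python
# Hypothetical driver loop for SharkSchedulerCPUWrapper; the wrapper's
# initialize / scale_model_input / step signatures match the diff above,
# while `unet` and the surrounding scaffolding are illustrative assumptions.
def run_denoise_loop(wrapper, unet, latents, guidance_scale):
    # initialize() scales the sample, builds add_time_ids, and returns the
    # timesteps now cached on the wrapper at construction time.
    sample, add_time_ids, step_indexes, timesteps = wrapper.initialize(latents)
    for i in range(len(timesteps)):
        # scale_model_input() indexes timesteps internally and returns the
        # scaled sample and t as device arrays.
        sample_scaled, t = wrapper.scale_model_input(sample, i, timesteps)
        noise_pred = unet(sample_scaled, t, add_time_ids)  # assumed UNet call
        # step() now takes (noise_pred, t, latents, ...); classifier-free
        # guidance is applied on the host before the scheduler step.
        sample = wrapper.step(noise_pred, t, sample, guidance_scale, i)
    return sample
```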

models/turbine_models/custom_models/sd_inference/utils.py

Lines changed: 0 additions & 1 deletion
@@ -136,7 +136,6 @@ def compile_to_vmfb(
             "--iree-hal-target-backends=rocm",
             "--iree-rocm-target-chip=" + target_triple,
             "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
-            "--iree-flow-inline-constants-max-byte-length=1",
         ]
     )
     if target_triple == "gfx942":
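
For reference, a minimal sketch of the ROCm flag set this hunk leaves behind; target_triple here is an example value, and the enclosing compile_to_vmfb() plumbing is omitted.

```python
# Resulting ROCm flag set after the removal above; target_triple is an
# example value, and the enclosing compile_to_vmfb() plumbing is omitted.
target_triple = "gfx90a"
rocm_flags = [
    "--iree-hal-target-backends=rocm",
    "--iree-rocm-target-chip=" + target_triple,
    "--iree-vm-bytecode-module-output-format=flatbuffer-binary",
]
print(" ".join(rocm_flags))
```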

models/turbine_models/custom_models/sdxl_inference/sdxl_cmd_opts.py

Lines changed: 7 additions & 0 deletions
@@ -130,6 +130,13 @@ def is_valid_file(arg):
     help="Use a decoupled unet and scheduler for better QOL.",
 )
 
+p.add_argument(
+    "--cpu_scheduling",
+    default=False,
+    action="store_true",
+    help="Run scheduling on torch cpu (will be slower due to data movement costs).",
+)
+
 p.add_argument(
     "--external_weight_file",
     type=str,
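
As a quick sanity check of the new option, here is a standalone sketch of the flag on a fresh argparse parser; in the repo the argument is registered on the shared parser p in sdxl_cmd_opts.py.

```python
# Standalone sketch mirroring the new option above; in the repo this is
# registered on the shared parser `p` in sdxl_cmd_opts.py.
import argparse

p = argparse.ArgumentParser()
p.add_argument(
    "--cpu_scheduling",
    default=False,
    action="store_true",
    help="Run scheduling on torch cpu (will be slower due to data movement costs).",
)

args = p.parse_args(["--cpu_scheduling"])
assert args.cpu_scheduling is True  # store_true flips the False default
```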
