Skip to content

Commit f6ab086

Browse files
committed
Multi-device support (SDXL)
1 parent 9656135 commit f6ab086

File tree

7 files changed

+236
-90
lines changed

7 files changed

+236
-90
lines changed

models/turbine_models/custom_models/sd3_inference/sd3_cmd_opts.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,14 @@ def is_valid_file(arg):
177177
help="Do one-shot inference from tokens to image in a shrink-wrapped pipeline binary.",
178178
)
179179

180+
p.add_argument(
181+
"--npu_delegate_path",
182+
type=str,
183+
default=None,
184+
help="Path to npu executable plugin .dll for running VAE on NPU.",
185+
)
186+
187+
180188
p.add_argument(
181189
"--clip_device",
182190
default=None,

models/turbine_models/custom_models/sd3_inference/sd3_pipeline.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -774,13 +774,18 @@ def run_diffusers_cpu(
774774
args.vae_decomp_attn,
775775
custom_vae=None,
776776
cpu_scheduling=args.cpu_scheduling,
777+
vae_precision=args.vae_precision,
777778
)
778779
vmfbs, weights = sd3_pipe.check_prepared(mlirs, vmfbs, weights)
779780
if args.cpu_scheduling:
780781
vmfbs.pop("scheduler")
781782
weights.pop("scheduler")
783+
if args.npu_delegate_path:
784+
extra_device_args = {"npu_delegate_path": args.npu_delegate_path}
785+
else:
786+
extra_device_args = {}
782787
sd3_pipe.load_pipeline(
783-
vmfbs, weights, args.compiled_pipeline, args.split_scheduler
788+
vmfbs, weights, args.compiled_pipeline, args.split_scheduler, extra_device_args=extra_device_args
784789
)
785790
sd3_pipe.generate_images(
786791
args.prompt,

models/turbine_models/custom_models/sd3_inference/sd3_vae.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,9 @@ def export_vae_model(
9090
)
9191
return vmfb_path
9292

93+
if device == "cpu":
94+
decomp_attn = True
95+
9396
if dtype == torch.float16:
9497
vae_model = vae_model.half()
9598
mapper = {}

models/turbine_models/custom_models/sd_inference/schedulers.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,6 @@ def initialize(self, sample):
160160
step_indexes = torch.tensor(len(self.module.timesteps))
161161
timesteps = self.timesteps
162162
sample = sample * self.module.init_noise_sigma
163-
print(sample, add_time_ids, step_indexes, timesteps)
164163
add_time_ids = ireert.asdevicearray(self.dest, add_time_ids, self.dtype)
165164
return sample, add_time_ids, step_indexes, timesteps
166165

@@ -184,11 +183,6 @@ def step(self, noise_pred, t, latents, guidance_scale, i):
184183
noise_pred = noise_pred_uncond + guidance_scale * (
185184
noise_pred_text - noise_pred_uncond
186185
)
187-
print(
188-
noise_pred[:, :, 0, 2],
189-
t,
190-
latents[:, :, 0, 2],
191-
)
192186
return self.module.step(
193187
noise_pred,
194188
t,

models/turbine_models/custom_models/sdxl_inference/sdxl_cmd_opts.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def is_valid_file(arg):
125125

126126
p.add_argument(
127127
"--split_scheduler",
128-
default=False,
128+
default=True,
129129
action="store_true",
130130
help="Use a decoupled unet and scheduler for better QOL.",
131131
)
@@ -158,6 +158,62 @@ def is_valid_file(arg):
158158
help="Do one-shot inference from tokens to image in a shrink-wrapped pipeline binary.",
159159
)
160160

161+
p.add_argument(
162+
"--vae_precision",
163+
type=str,
164+
default="fp16",
165+
help="Precision of VAE weights and graph.",
166+
)
167+
168+
p.add_argument(
169+
"--npu_delegate_path",
170+
type=str,
171+
default=None,
172+
help="Path to npu executable plugin .dll for running VAE on NPU.",
173+
)
174+
175+
p.add_argument(
176+
"--clip_device",
177+
default=None,
178+
type=str,
179+
help="Device to run CLIP on. If None, defaults to the device specified in args.device.",
180+
)
181+
182+
p.add_argument(
183+
"--unet_device",
184+
default=None,
185+
type=str,
186+
help="Device to run unet on. If None, defaults to the device specified in args.device.",
187+
)
188+
189+
p.add_argument(
190+
"--vae_device",
191+
default=None,
192+
type=str,
193+
help="Device to run VAE on. If None, defaults to the device specified in args.device.",
194+
)
195+
196+
p.add_argument(
197+
"--clip_target",
198+
default=None,
199+
type=str,
200+
help="IREE target for CLIP compilation. If None, defaults to the target specified by --iree_target_triple.",
201+
)
202+
203+
p.add_argument(
204+
"--unet_target",
205+
default=None,
206+
type=str,
207+
help="IREE target for unet compilation. If None, defaults to the target specified by --iree_target_triple.",
208+
)
209+
210+
p.add_argument(
211+
"--vae_target",
212+
default=None,
213+
type=str,
214+
help="IREE target for vae compilation. If None, defaults to the target specified by --iree_target_triple.",
215+
)
216+
161217
##############################################################################
162218
# SDXL Modelling Options
163219
# These options are used to control model defining parameters for SDXL.

0 commit comments

Comments
 (0)