Skip to content

Commit a0e4792

Browse files
authored
(SD) Add benchmark option and add a printer. (#773)
usage: `--benchmark=all` `--benchmark=unet` `--benchmark=clip,vae` `--verbose`
1 parent eb61c14 commit a0e4792

File tree

3 files changed

+93
-17
lines changed

3 files changed

+93
-17
lines changed

models/turbine_models/custom_models/pipeline_base.py

Lines changed: 64 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,15 @@ class PipelineComponent:
8383
This aims to make new pipelines and execution modes easier to write, manage, and debug.
8484
"""
8585

86-
def __init__(self, dest_type="devicearray", dest_dtype="float16"):
86+
def __init__(
87+
self, printer, dest_type="devicearray", dest_dtype="float16", benchmark=False
88+
):
8789
self.runner = None
8890
self.module_name = None
8991
self.device = None
9092
self.metadata = None
91-
self.benchmark = False
93+
self.printer = printer
94+
self.benchmark = benchmark
9295
self.dest_type = dest_type
9396
self.dest_dtype = dest_dtype
9497

@@ -101,7 +104,7 @@ def load(
101104
extra_plugin=None,
102105
):
103106
self.module_name = module_name
104-
print(
107+
self.printer.print(
105108
f"Loading {module_name} from {vmfb_path} with external weights: {external_weight_path}."
106109
)
107110
self.runner = vmfbRunner(
@@ -222,7 +225,9 @@ def _run_and_benchmark(self, function_name, inputs: list):
222225
start_time = time.time()
223226
output = self._run(function_name, inputs)
224227
latency = time.time() - start_time
225-
print(f"Latency for {self.module_name}['{function_name}']: {latency}sec")
228+
self.printer.print(
229+
f"Latency for {self.module_name}['{function_name}']: {latency}sec"
230+
)
226231
return output
227232

228233
def __call__(self, function_name, inputs: list):
@@ -238,6 +243,41 @@ def __call__(self, function_name, inputs: list):
238243
return output
239244

240245

246+
class Printer:
    """Small conditional logger that can prefix messages with elapsed time.

    Attributes:
        verbose: 0 for silence, 1 (truthy) to emit messages.
        start_time: reference time (seconds since epoch) of construction or
            the most recent reset(); used for the "t=" prefix.
        last_print: time of the last emitted timed message; used for "dt=".
        print_time: 1 (truthy) to prefix messages with "[t=... dt=...]",
            0 to print the bare message.
    """

    def __init__(self, verbose, start_time, print_time):
        """
        verbose: 0 for silence, 1 for print
        start_time: time of construction (or reset) of this Printer
        print_time: 1 to print with time prefix, 0 to not
        """
        self.verbose = verbose
        self.start_time = start_time
        self.last_print = start_time
        self.print_time = print_time

    def reset(self):
        """Reset the printer's clock so elapsed time restarts at 0.0 s.

        Fix: the clock is now reset even when print_time is disabled.
        Previously reset() was a silent no-op in that case, leaving stale
        start_time/last_print values. Both timestamps are also taken from a
        single time.time() call so they cannot drift apart.
        """
        if self.verbose and self.print_time:
            self.print("Will now reset clock for printer to 0.0 [s].")
        now = time.time()
        self.last_print = now
        self.start_time = now
        if self.verbose and self.print_time:
            self.print("Clock for printer reset to t = 0.0 [s].")

    def print(self, message):
        """Emit `message` if verbose; with a timing prefix if print_time.

        The timed form looks like "[t=0.123 dt=0.004] message", where t is
        time since start_time and dt is time since the previous timed print.
        """
        if not self.verbose:
            return
        if self.print_time:
            time_now = time.time()
            print(
                f"[t={time_now - self.start_time:.3f} dt={time_now - self.last_print:.3f}] {message}"
            )
            self.last_print = time_now
        else:
            print(f"{message}")
241281
class TurbinePipelineBase:
242282
"""
243283
This class is a lightweight base for Stable Diffusion
@@ -298,9 +338,13 @@ def __init__(
298338
pipeline_dir: str = "./shark_vmfbs",
299339
external_weights_dir: str = "./shark_weights",
300340
hf_model_name: str | dict[str] = None,
341+
benchmark: bool | dict[bool] = False,
342+
verbose: bool = False,
301343
common_export_args: dict = {},
302344
):
303345
self.map = model_map
346+
self.verbose = verbose
347+
self.printer = Printer(self.verbose, time.time(), True)
304348
if isinstance(device, dict):
305349
assert isinstance(
306350
target, dict
@@ -329,6 +373,7 @@ def __init__(
329373
"decomp_attn": decomp_attn,
330374
"external_weights": external_weights,
331375
"hf_model_name": hf_model_name,
376+
"benchmark": benchmark,
332377
}
333378
for arg in map_arguments.keys():
334379
self.map = merge_arg_into_map(self.map, map_arguments[arg], arg)
@@ -396,7 +441,7 @@ def prepare_all(
396441
ready = self.is_prepared(vmfbs, weights)
397442
match ready:
398443
case True:
399-
print("All necessary files found.")
444+
self.printer.print("All necessary files found.")
400445
return
401446
case False:
402447
if interactive:
@@ -407,7 +452,7 @@ def prepare_all(
407452
exit()
408453
for submodel in self.map.keys():
409454
if not self.map[submodel].get("vmfb"):
410-
print("Fetching: ", submodel)
455+
self.printer.print("Fetching: ", submodel)
411456
self.export_submodel(
412457
submodel, input_mlir=self.map[submodel].get("mlir")
413458
)
@@ -456,8 +501,6 @@ def is_prepared(self, vmfbs, weights):
456501
mlir_keywords.remove(kw)
457502
avail_files = os.listdir(pipeline_dir)
458503
candidates = []
459-
# print("MLIR KEYS: ", mlir_keywords)
460-
# print("AVAILABLE FILES: ", avail_files)
461504
for filename in avail_files:
462505
if all(str(x) in filename for x in keywords) and not any(
463506
x in filename for x in neg_keywords
@@ -470,8 +513,8 @@ def is_prepared(self, vmfbs, weights):
470513
if len(candidates) == 1:
471514
self.map[key]["vmfb"] = candidates[0]
472515
elif len(candidates) > 1:
473-
print(f"Multiple files found for {key}: {candidates}")
474-
print(f"Choosing {candidates[0]} for {key}.")
516+
self.printer.print(f"Multiple files found for {key}: {candidates}")
517+
self.printer.print(f"Choosing {candidates[0]} for {key}.")
475518
self.map[key]["vmfb"] = candidates[0]
476519
else:
477520
# vmfb not found in pipeline_dir. Add to list of files to generate.
@@ -503,16 +546,18 @@ def is_prepared(self, vmfbs, weights):
503546
if len(candidates) == 1:
504547
self.map[key]["weights"] = candidates[0]
505548
elif len(candidates) > 1:
506-
print(f"Multiple weight files found for {key}: {candidates}")
507-
print(f"Choosing {candidates[0]} for {key}.")
549+
self.printer.print(
550+
f"Multiple weight files found for {key}: {candidates}"
551+
)
552+
self.printer.print(f"Choosing {candidates[0]} for {key}.")
508553
self.map[key][weights] = candidates[0]
509554
elif self.map[key].get("external_weights"):
510555
# weights not found in external_weights_dir. Add to list of files to generate.
511556
missing[key].append("weights")
512557
if not any(x for x in missing.values()):
513558
ready = True
514559
else:
515-
print("Missing files: ", missing)
560+
self.printer.print("Missing files: ", missing)
516561
ready = False
517562
return ready
518563

@@ -678,7 +723,7 @@ def export_submodel(
678723
def load_map(self):
679724
for submodel in self.map.keys():
680725
if not self.map[submodel]["load"]:
681-
print("Skipping load for ", submodel)
726+
self.printer.print("Skipping load for ", submodel)
682727
continue
683728
self.load_submodel(submodel)
684729

@@ -690,7 +735,11 @@ def load_submodel(self, submodel):
690735
):
691736
raise ValueError(f"Weights not found for {submodel}.")
692737
dest_type = self.map[submodel].get("dest_type", "devicearray")
693-
self.map[submodel]["runner"] = PipelineComponent(dest_type=dest_type)
738+
self.map[submodel]["runner"] = PipelineComponent(
739+
printer=self.printer,
740+
dest_type=dest_type,
741+
benchmark=self.map[submodel].get("benchmark", False),
742+
)
694743
self.map[submodel]["runner"].load(
695744
self.map[submodel]["driver"],
696745
self.map[submodel]["vmfb"],

models/turbine_models/custom_models/sd_inference/sd_cmd_opts.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,13 @@ def is_valid_file(arg):
144144
help="Run scheduling on native pytorch CPU backend.",
145145
)
146146

147+
p.add_argument(
148+
"--benchmark",
149+
type=str,
150+
default=None,
151+
help="A comma-separated list of submodel IDs for which to report benchmarks for, or 'all' for all components.",
152+
)
153+
147154
##############################################################################
148155
# SDXL Modelling Options
149156
# These options are used to control model defining parameters for SDXL.
@@ -198,6 +205,7 @@ def is_valid_file(arg):
198205

199206
p.add_argument("--compile_to", type=str, default="mlir", help="torch, linalg, vmfb")
200207

208+
p.add_argument("--verbose", "-v", action="store_true")
201209
p.add_argument(
202210
"--external_weights",
203211
type=str,

models/turbine_models/custom_models/sd_inference/sd_pipeline.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,9 @@ def __init__(
230230
scheduler_id: str = None, # compatibility only
231231
shift: float = 1.0, # compatibility only
232232
use_i8_punet: bool = False,
233+
benchmark: bool | dict[bool] = False,
234+
verbose: bool = False,
235+
batch_prompts: bool = False,
233236
):
234237
common_export_args = {
235238
"hf_model_name": None,
@@ -276,6 +279,8 @@ def __init__(
276279
pipeline_dir,
277280
external_weights_dir,
278281
hf_model_name,
282+
benchmark,
283+
verbose,
279284
common_export_args,
280285
)
281286
for submodel in sd_model_map:
@@ -329,6 +334,7 @@ def __init__(
329334
self.base_model_name, subfolder="tokenizer_2"
330335
),
331336
]
337+
self.map["text_encoder"]["export_args"]["batch_input"] = batch_prompts
332338
self.latents_precision = self.map["unet"]["precision"]
333339
self.scheduler_device = self.map["unet"]["device"]
334340
self.scheduler_driver = self.map["unet"]["driver"]
@@ -559,7 +565,10 @@ def _produce_latents_sdxl(
559565
[guidance_scale],
560566
dtype=self.map["unet"]["np_dtype"],
561567
)
562-
for i, t in tqdm(enumerate(timesteps)):
568+
for i, t in tqdm(
569+
enumerate(timesteps),
570+
disable=(self.map["unet"].get("benchmark") and self.verbose),
571+
):
563572
if self.cpu_scheduling:
564573
latent_model_input, t = self.scheduler.scale_model_input(
565574
latents,
@@ -571,7 +580,6 @@ def _produce_latents_sdxl(
571580
latent_model_input, t = self.scheduler(
572581
"run_scale", [latents, step, timesteps]
573582
)
574-
575583
unet_inputs = [
576584
latent_model_input,
577585
t,
@@ -703,6 +711,15 @@ def numpy_to_pil_image(images):
703711
}
704712
if not args.pipeline_dir:
705713
args.pipeline_dir = utils.create_safe_name(args.hf_model_name, "")
714+
benchmark = {}
715+
if args.benchmark:
716+
if args.benchmark.lower() == "all":
717+
benchmark = True
718+
else:
719+
for i in args.benchmark.split(","):
720+
benchmark[i] = True
721+
else:
722+
benchmark = False
706723
if any(x for x in [args.vae_decomp_attn, args.unet_decomp_attn]):
707724
args.decomp_attn = {
708725
"text_encoder": args.decomp_attn,
@@ -731,6 +748,8 @@ def numpy_to_pil_image(images):
731748
args.scheduler_id,
732749
None,
733750
args.use_i8_punet,
751+
benchmark,
752+
args.verbose,
734753
)
735754
sd_pipe.prepare_all()
736755
sd_pipe.load_map()

0 commit comments

Comments
 (0)