@@ -83,12 +83,15 @@ class PipelineComponent:
8383 This aims to make new pipelines and execution modes easier to write, manage, and debug.
8484 """
8585
86- def __init__ (self , dest_type = "devicearray" , dest_dtype = "float16" ):
86+ def __init__ (
87+ self , printer , dest_type = "devicearray" , dest_dtype = "float16" , benchmark = False
88+ ):
8789 self .runner = None
8890 self .module_name = None
8991 self .device = None
9092 self .metadata = None
91- self .benchmark = False
93+ self .printer = printer
94+ self .benchmark = benchmark
9295 self .dest_type = dest_type
9396 self .dest_dtype = dest_dtype
9497
@@ -101,7 +104,7 @@ def load(
101104 extra_plugin = None ,
102105 ):
103106 self .module_name = module_name
104- print (
107+ self . printer . print (
105108 f"Loading { module_name } from { vmfb_path } with external weights: { external_weight_path } ."
106109 )
107110 self .runner = vmfbRunner (
@@ -222,7 +225,9 @@ def _run_and_benchmark(self, function_name, inputs: list):
222225 start_time = time .time ()
223226 output = self ._run (function_name , inputs )
224227 latency = time .time () - start_time
225- print (f"Latency for { self .module_name } ['{ function_name } ']: { latency } sec" )
228+ self .printer .print (
229+ f"Latency for { self .module_name } ['{ function_name } ']: { latency } sec"
230+ )
226231 return output
227232
228233 def __call__ (self , function_name , inputs : list ):
@@ -238,6 +243,41 @@ def __call__(self, function_name, inputs: list):
238243 return output
239244
240245
246+ class Printer :
247+ def __init__ (self , verbose , start_time , print_time ):
248+ """
249+ verbose: 0 for silence, 1 for print
250+ start_time: time of construction (or reset) of this Printer
251+ last_print: time of last call to 'print' method
252+ print_time: 1 to print with time prefix, 0 to not
253+ """
254+ self .verbose = verbose
255+ self .start_time = start_time
256+ self .last_print = start_time
257+ self .print_time = print_time
258+
259+ def reset (self ):
260+ if self .print_time :
261+ if self .verbose :
262+ self .print ("Will now reset clock for printer to 0.0 [s]." )
263+ self .last_print = time .time ()
264+ self .start_time = time .time ()
265+ if self .verbose :
266+ self .print ("Clock for printer reset to t = 0.0 [s]." )
267+
268+ def print (self , message ):
269+ if self .verbose :
270+ # Print something like "[t=0.123 dt=0.004] 'message'"
271+ if self .print_time :
272+ time_now = time .time ()
273+ print (
274+ f"[t={ time_now - self .start_time :.3f} dt={ time_now - self .last_print :.3f} ] { message } "
275+ )
276+ self .last_print = time_now
277+ else :
278+ print (f"{ message } " )
279+
280+
241281class TurbinePipelineBase :
242282 """
243283 This class is a lightweight base for Stable Diffusion
@@ -298,9 +338,13 @@ def __init__(
298338 pipeline_dir : str = "./shark_vmfbs" ,
299339 external_weights_dir : str = "./shark_weights" ,
300340 hf_model_name : str | dict [str ] = None ,
341+ benchmark : bool | dict [bool ] = False ,
342+ verbose : bool = False ,
301343 common_export_args : dict = {},
302344 ):
303345 self .map = model_map
346+ self .verbose = verbose
347+ self .printer = Printer (self .verbose , time .time (), True )
304348 if isinstance (device , dict ):
305349 assert isinstance (
306350 target , dict
@@ -329,6 +373,7 @@ def __init__(
329373 "decomp_attn" : decomp_attn ,
330374 "external_weights" : external_weights ,
331375 "hf_model_name" : hf_model_name ,
376+ "benchmark" : benchmark ,
332377 }
333378 for arg in map_arguments .keys ():
334379 self .map = merge_arg_into_map (self .map , map_arguments [arg ], arg )
@@ -396,7 +441,7 @@ def prepare_all(
396441 ready = self .is_prepared (vmfbs , weights )
397442 match ready :
398443 case True :
399- print ("All necessary files found." )
444+ self . printer . print ("All necessary files found." )
400445 return
401446 case False :
402447 if interactive :
@@ -407,7 +452,7 @@ def prepare_all(
407452 exit ()
408453 for submodel in self .map .keys ():
409454 if not self .map [submodel ].get ("vmfb" ):
410- print ("Fetching: " , submodel )
455+ self . printer . print ("Fetching: " , submodel )
411456 self .export_submodel (
412457 submodel , input_mlir = self .map [submodel ].get ("mlir" )
413458 )
@@ -456,8 +501,6 @@ def is_prepared(self, vmfbs, weights):
456501 mlir_keywords .remove (kw )
457502 avail_files = os .listdir (pipeline_dir )
458503 candidates = []
459- # print("MLIR KEYS: ", mlir_keywords)
460- # print("AVAILABLE FILES: ", avail_files)
461504 for filename in avail_files :
462505 if all (str (x ) in filename for x in keywords ) and not any (
463506 x in filename for x in neg_keywords
@@ -470,8 +513,8 @@ def is_prepared(self, vmfbs, weights):
470513 if len (candidates ) == 1 :
471514 self .map [key ]["vmfb" ] = candidates [0 ]
472515 elif len (candidates ) > 1 :
473- print (f"Multiple files found for { key } : { candidates } " )
474- print (f"Choosing { candidates [0 ]} for { key } ." )
516+ self . printer . print (f"Multiple files found for { key } : { candidates } " )
517+ self . printer . print (f"Choosing { candidates [0 ]} for { key } ." )
475518 self .map [key ]["vmfb" ] = candidates [0 ]
476519 else :
477520 # vmfb not found in pipeline_dir. Add to list of files to generate.
@@ -503,16 +546,18 @@ def is_prepared(self, vmfbs, weights):
503546 if len (candidates ) == 1 :
504547 self .map [key ]["weights" ] = candidates [0 ]
505548 elif len (candidates ) > 1 :
506- print (f"Multiple weight files found for { key } : { candidates } " )
507- print (f"Choosing { candidates [0 ]} for { key } ." )
549+ self .printer .print (
550+ f"Multiple weight files found for { key } : { candidates } "
551+ )
552+ self .printer .print (f"Choosing { candidates [0 ]} for { key } ." )
508553 self .map [key ][weights ] = candidates [0 ]
509554 elif self .map [key ].get ("external_weights" ):
510555 # weights not found in external_weights_dir. Add to list of files to generate.
511556 missing [key ].append ("weights" )
512557 if not any (x for x in missing .values ()):
513558 ready = True
514559 else :
515- print ("Missing files: " , missing )
560+ self . printer . print ("Missing files: " , missing )
516561 ready = False
517562 return ready
518563
@@ -678,7 +723,7 @@ def export_submodel(
678723 def load_map (self ):
679724 for submodel in self .map .keys ():
680725 if not self .map [submodel ]["load" ]:
681- print ("Skipping load for " , submodel )
726+ self . printer . print ("Skipping load for " , submodel )
682727 continue
683728 self .load_submodel (submodel )
684729
@@ -690,7 +735,11 @@ def load_submodel(self, submodel):
690735 ):
691736 raise ValueError (f"Weights not found for { submodel } ." )
692737 dest_type = self .map [submodel ].get ("dest_type" , "devicearray" )
693- self .map [submodel ]["runner" ] = PipelineComponent (dest_type = dest_type )
738+ self .map [submodel ]["runner" ] = PipelineComponent (
739+ printer = self .printer ,
740+ dest_type = dest_type ,
741+ benchmark = self .map [submodel ].get ("benchmark" , False ),
742+ )
694743 self .map [submodel ]["runner" ].load (
695744 self .map [submodel ]["driver" ],
696745 self .map [submodel ]["vmfb" ],
0 commit comments