1717from graph_net import path_utils
1818from graph_net import test_compiler_util
1919
20+ from graph_net .paddle .backend .graph_compiler_backend import GraphCompilerBackend
21+ from graph_net .paddle .backend .cinn_backend import CinnBackend
22+ from graph_net .paddle .backend .nope_backend import NopeBackend
23+
24+
# Registry mapping the `--compiler` CLI value to a ready-made backend
# instance. Backends are constructed eagerly at import time, so both
# constructors are assumed to be cheap and side-effect free.
registry_backend = {
    "cinn": CinnBackend(),
    "nope": NopeBackend(),
}
29+
30+
def get_compiler_backend(args) -> GraphCompilerBackend:
    """Look up the backend instance registered for ``args.compiler``.

    Args:
        args: Parsed CLI namespace; only ``args.compiler`` (str) is read.

    Returns:
        The ``GraphCompilerBackend`` registered under that name.

    Raises:
        ValueError: if ``args.compiler`` is not a registered backend name.
    """
    # NOTE: the original used `assert`, which is stripped under `python -O`
    # and would have let an unknown name fall through to a raw KeyError.
    # An explicit raise keeps the validation active in optimized mode.
    try:
        return registry_backend[args.compiler]
    except KeyError:
        raise ValueError(f"Unknown compiler: {args.compiler}") from None
34+
2035
2136def set_seed (random_seed ):
2237 paddle .seed (random_seed )
@@ -60,10 +75,6 @@ def load_class_from_file(file_path: str, class_name: str):
6075 return model_class
6176
6277
63- def get_synchronizer_func (args ):
64- return paddle .device .synchronize
65-
66-
6778def get_model (model_path ):
6879 model_class = load_class_from_file (
6980 f"{ model_path } /model.py" , class_name = "GraphModule"
@@ -91,22 +102,6 @@ def get_input_spec(model_path):
91102 return input_spec
92103
93104
94- def get_compiled_model (args , model ):
95- if args .compiler == "nope" :
96- return model
97- input_spec = get_input_spec (args .model_path )
98- build_strategy = paddle .static .BuildStrategy ()
99- compiled_model = paddle .jit .to_static (
100- model ,
101- input_spec = input_spec ,
102- build_strategy = build_strategy ,
103- full_graph = True ,
104- )
105- compiled_model .eval ()
106- program = compiled_model .forward .concrete_program .main_program
107- return compiled_model
108-
109-
110105def get_static_model (args , model ):
111106 static_model = paddle .jit .to_static (
112107 model ,
@@ -119,7 +114,7 @@ def get_static_model(args, model):
119114 return static_model
120115
121116
122- def measure_performance (model_call , args , synchronizer_func , profile = False ):
117+ def measure_performance (model_call , args , compiler , profile = False ):
123118 runtime_seed = 1024
124119 stats = {}
125120
@@ -129,7 +124,7 @@ def measure_performance(model_call, args, synchronizer_func, profile=False):
129124 # Warmup runs
130125 for _ in range (args .warmup ):
131126 model_call ()
132- synchronizer_func ()
127+ compiler . synchronize ()
133128
134129 hardware_name = get_hardward_name (args )
135130 print (
@@ -152,7 +147,7 @@ def measure_performance(model_call, args, synchronizer_func, profile=False):
152147 for i in range (args .trials ):
153148 # End-to-end timing (naive_timer)
154149 duration_box = test_compiler_util .DurationBox (- 1 )
155- with test_compiler_util .naive_timer (duration_box , synchronizer_func ):
150+ with test_compiler_util .naive_timer (duration_box , compiler . synchronize ):
156151 # GPU-only timing (CUDA Events)
157152 start_event = paddle .device .Event (enable_timing = True )
158153 end_event = paddle .device .Event (enable_timing = True )
@@ -178,7 +173,7 @@ def measure_performance(model_call, args, synchronizer_func, profile=False):
178173 e2e_times = []
179174 for i in range (args .trials ):
180175 duration_box = test_compiler_util .DurationBox (- 1 )
181- with test_compiler_util .naive_timer (duration_box , synchronizer_func ):
176+ with test_compiler_util .naive_timer (duration_box , compiler . synchronize ):
182177 model_call ()
183178 print (
184179 f"Trial { i + 1 } : e2e={ duration_box .value :.4f} ms" ,
@@ -247,8 +242,25 @@ def transfer_to_float(origin_outputs):
247242 )
248243
249244
def check_and_print_gpu_utilization(compiler):
    """Report GPU and memory utilization of the active CUDA device to stderr.

    Silently does nothing when Paddle was built without CUDA support or
    when utilization numbers are unavailable. ``compiler.synchronize`` is
    passed to the utilization helper so pending device work is flushed
    before sampling.
    """
    if not paddle.device.is_compiled_with_cuda():
        return
    # Device string looks like "gpu:0"; the trailing field is the ordinal.
    device_id = int(paddle.device.get_device().split(":")[-1])
    device_count = paddle.device.cuda.device_count()
    gpu_util, mem_util = test_compiler_util.get_device_utilization(
        device_id, device_count, compiler.synchronize
    )
    if gpu_util is None or mem_util is None:
        return
    print(
        f"Device status: gpu_id {device_id}, gpu_util {gpu_util:.2f}%, mem_util {mem_util:.2f}%",
        file=sys.stderr,
        flush=True,
    )
258+
259+
250260def test_single_model (args ):
251- synchronizer_func = get_synchronizer_func (args )
261+ compiler = get_compiler_backend (args )
262+ check_and_print_gpu_utilization (compiler )
263+
252264 input_dict = get_input_dict (args .model_path )
253265 model = get_model (args .model_path )
254266 model .eval ()
@@ -264,7 +276,7 @@ def test_single_model(args):
264276 print ("Run model in eager mode." , file = sys .stderr , flush = True )
265277 static_model = get_static_model (args , model )
266278 expected_out , eager_time_stats = measure_performance (
267- lambda : static_model (** input_dict ), args , synchronizer_func , profile = False
279+ lambda : static_model (** input_dict ), args , compiler , profile = False
268280 )
269281 eager_success = True
270282 except Exception as e :
@@ -279,9 +291,10 @@ def test_single_model(args):
279291 compiled_time_stats = {}
280292 try :
281293 print ("Run model in compiled mode." , file = sys .stderr , flush = True )
282- compiled_model = get_compiled_model (args , model )
294+ input_spec = get_input_spec (args .model_path )
295+ compiled_model = compiler (model , input_spec )
283296 compiled_out , compiled_time_stats = measure_performance (
284- lambda : compiled_model (** input_dict ), args , synchronizer_func , profile = False
297+ lambda : compiled_model (** input_dict ), args , compiler , profile = False
285298 )
286299 compiled_success = True
287300 except Exception as e :
@@ -415,18 +428,6 @@ def main(args):
415428 set_seed (random_seed = initalize_seed )
416429
417430 if path_utils .is_single_model_dir (args .model_path ):
418- if paddle .device .is_compiled_with_cuda ():
419- device_id = int (paddle .device .get_device ().split (":" )[- 1 ])
420- device_count = paddle .device .cuda .device_count ()
421- gpu_util , mem_util = test_compiler_util .get_device_utilization (
422- device_id , device_count , get_synchronizer_func (args )
423- )
424- if gpu_util is not None and mem_util is not None :
425- print (
426- f"Device status: gpu_id { device_id } , gpu_util { gpu_util :.2f} %, mem_util { mem_util :.2f} %" ,
427- file = sys .stderr ,
428- flush = True ,
429- )
430431 test_single_model (args )
431432 else :
432433 test_multi_models (args )
0 commit comments