99import time
1010import numpy as np
1111import random
12+ import platform
1213
13- from . import utils
14+ from graph_net .paddle import utils
15+ from graph_net .benchmark_result import BenchmarkResult
1416
1517
1618def load_class_from_file (file_path : str , class_name : str ):
@@ -201,6 +203,27 @@ def measure_performance(model_call, synchronizer_func, args, profile=False):
201203 return outs , times
202204
203205
206+ def init_benchmark_result (args ):
207+ if args .device == "cuda" :
208+ hardware = paddle .device .cuda .get_device_name (0 )
209+ elif args .device == "cpu" :
210+ hardware = platform .processor ()
211+ else :
212+ hardware = "unknown"
213+
214+ if args .compiler == "CINN" :
215+ compile_framework_version = paddle .__version__
216+ else :
217+ compile_framework_version = "unknown"
218+
219+ result_data = BenchmarkResult (
220+ args = args ,
221+ hardware = hardware ,
222+ compile_framework_version = compile_framework_version ,
223+ )
224+ return result_data
225+
226+
204227def test_single_model (args ):
205228 synchronizer_func = get_synchronizer_func (args )
206229 input_dict , input_dtypes , param_dtypes = get_input_dict (args )
@@ -210,12 +233,16 @@ def test_single_model(args):
210233 # Collect model information
211234 num_ops = count_number_of_ops (args , model )
212235
213- print ("Run on eager mode" )
236+ # Initialize benchmark result
237+ result_data = init_benchmark_result (args )
238+ result_data .update_model_info (num_ops , input_dtypes , param_dtypes )
239+
240+ # Run on eager mode
214241 expected_out , eager_time_ms = measure_performance (
215242 lambda : model (** input_dict ), synchronizer_func , args , profile = False
216243 )
217244
218- print ( " Run on compiling mode" )
245+ # Run on compiling mode
219246 compiled_model = get_compiled_model (args , model )
220247 compiled_out , compiled_time_ms = measure_performance (
221248 lambda : compiled_model (** input_dict ), synchronizer_func , args , profile = False
@@ -243,6 +270,7 @@ def test_single_model(args):
243270
244271 def print_cmp (key , func , ** kwargs ):
245272 cmp_ret = func (expected_out , compiled_out , ** kwargs )
273+ result_data .update_corrrectness (key , cmp_ret )
246274 print (
247275 f"{ args .log_prompt } { key } model_path:{ args .model_path } { cmp_ret } " ,
248276 file = sys .stderr ,
@@ -271,6 +299,10 @@ def print_cmp(key, func, **kwargs):
271299 file = sys .stderr ,
272300 )
273301
302+ result_data .update_performance (eager_time_ms , compiled_time_ms )
303+ if args .output_dir :
304+ result_data .write_to_json (args .output_dir )
305+
274306
275307def get_cmp_equal (expected_out , compiled_out ):
276308 return " " .join (
@@ -372,6 +404,13 @@ def main(args):
372404 default = "CINN" ,
373405 help = "Path to customized compiler python file" ,
374406 )
407+ parser .add_argument (
408+ "--device" ,
409+ type = str ,
410+ required = False ,
411+ default = "cuda" ,
412+ help = "Device for testing the compiler (e.g., 'cpu' or 'cuda')" ,
413+ )
375414 parser .add_argument (
376415 "--warmup" , type = int , required = False , default = 5 , help = "Number of warmup steps"
377416 )
@@ -391,5 +430,12 @@ def main(args):
391430 default = "graph-net-test-compiler-log" ,
392431 help = "Log prompt for performance log filtering." ,
393432 )
433+ parser .add_argument (
434+ "--output-dir" ,
435+ type = str ,
436+ required = False ,
437+ default = None ,
438+ help = "Directory to save the structured JSON result file." ,
439+ )
394440 args = parser .parse_args ()
395441 main (args = args )
0 commit comments