14
14
updated_cases = []
15
15
removed_cases = []
16
16
17
+ def safe_float_convert (value ):
18
+ try :
19
+ return float (value ) if value .strip () else None
20
+ except (ValueError , AttributeError ):
21
+ return None
22
+
17
23
def update_baseline (xpu_file , baseline_file , remove_missing = False ):
18
24
with open (xpu_file ) as f :
19
25
xpu_reader = csv .DictReader (f , delimiter = ';' )
20
26
xpu_rows = list (xpu_reader )
21
- xpu_fieldnames = xpu_reader .fieldnames # Keep original field order
22
- fieldnames = [f for f in xpu_fieldnames if f not in ['time(us)' , 'E2E total time(us)' , 'E2E forward time(us)' ]]
23
- xpu_data = {make_key (row , fieldnames ): (float (row ['time(us)' ]), row ) for row in xpu_rows }
27
+ xpu_fieldnames = xpu_reader .fieldnames
28
+ time_fields = ['time(us)' , 'E2E total time(us)' , 'E2E forward time(us)' ]
29
+ fieldnames = [f for f in xpu_fieldnames if f not in time_fields ]
30
+ xpu_data = {}
31
+ for row in xpu_rows :
32
+ key = make_key (row , fieldnames )
33
+ time_values = {}
34
+ if 'time(us)' in row :
35
+ time_val = safe_float_convert (row ['time(us)' ])
36
+ if time_val is not None :
37
+ time_values ['time(us)' ] = time_val
38
+ if 'E2E total time(us)' in row :
39
+ e2e_val = safe_float_convert (row ['E2E total time(us)' ])
40
+ if e2e_val is not None :
41
+ time_values ['E2E total time(us)' ] = e2e_val
42
+ xpu_data [key ] = (time_values , row )
24
43
25
44
with open (baseline_file ) as f :
26
45
baseline_reader = csv .DictReader (f , delimiter = ';' )
27
46
baseline_rows = list (baseline_reader )
28
47
baseline_fieldnames = baseline_reader .fieldnames
29
48
30
49
# To add new parameter of new ops into baseline file
31
- all_fieldnames = xpu_fieldnames + [f for f in baseline_fieldnames if f not in xpu_fieldnames ]
32
- fieldnames = [f for f in all_fieldnames if f not in ['time(us)' , 'E2E total time(us)' , 'E2E forward time(us)' ]]
50
+ all_fieldnames = list (set (xpu_fieldnames + baseline_fieldnames ))
51
+ # Maintain original order as much as possible
52
+ ordered_fieldnames = []
53
+ for f in xpu_fieldnames :
54
+ if f in all_fieldnames and f not in ordered_fieldnames :
55
+ ordered_fieldnames .append (f )
56
+ for f in baseline_fieldnames :
57
+ if f in all_fieldnames and f not in ordered_fieldnames :
58
+ ordered_fieldnames .append (f )
33
59
34
60
baseline_keys = {make_key (row , fieldnames ) for row in baseline_rows }
35
61
xpu_keys = set (xpu_data .keys ())
@@ -38,75 +64,117 @@ def update_baseline(xpu_file, baseline_file, remove_missing=False):
38
64
for row in baseline_rows :
39
65
key = make_key (row , fieldnames )
40
66
if key in xpu_data :
41
- xpu_time , xpu_row = xpu_data [key ]
42
- baseline_time = float (row ['time(us)' ])
43
-
44
- if xpu_time < baseline_time :
45
- updated_row = {}
46
- for field in all_fieldnames :
47
- updated_row [field ] = xpu_row .get (field , row .get (field , '' ))
48
- updated_row ['time(us)' ] = str (xpu_time )
49
- if 'E2E total time(us)' in row :
50
- updated_row ['E2E total time(us)' ] = row ['E2E total time(us)' ]
51
- updated_cases .append ((key , baseline_time , xpu_time , updated_row ))
52
- updated_rows .append (updated_row )
53
- else :
54
- ordered_row = {}
55
- for field in all_fieldnames :
56
- ordered_row [field ] = row .get (field , '' )
57
- updated_rows .append (ordered_row )
67
+ xpu_times , xpu_row = xpu_data [key ]
68
+ updated_row = {}
69
+
70
+ # Copy all fields from baseline first
71
+ for field in ordered_fieldnames :
72
+ updated_row [field ] = row .get (field , '' )
73
+
74
+ # Update with xpu values where they exist
75
+ for field in ordered_fieldnames :
76
+ if field in xpu_row and xpu_row [field ]:
77
+ updated_row [field ] = xpu_row [field ]
78
+
79
+ # Handle time fields
80
+ updated = False
81
+ if 'time(us)' in xpu_times and 'time(us)' in row :
82
+ baseline_time = safe_float_convert (row ['time(us)' ])
83
+ if baseline_time is not None :
84
+ xpu_time = xpu_times ['time(us)' ]
85
+ if xpu_time < baseline_time :
86
+ updated_row ['time(us)' ] = str (xpu_time )
87
+ updated = True
88
+
89
+ if 'E2E total time(us)' in xpu_times and 'E2E total time(us)' in row :
90
+ baseline_e2e = safe_float_convert (row ['E2E total time(us)' ])
91
+ if baseline_e2e is not None :
92
+ xpu_e2e = xpu_times ['E2E total time(us)' ]
93
+ if xpu_e2e < baseline_e2e :
94
+ updated_row ['E2E total time(us)' ] = str (xpu_e2e )
95
+ updated = True
96
+
97
+ if updated :
98
+ updated_cases .append ((key , row , updated_row ))
99
+ updated_rows .append (updated_row )
58
100
elif not remove_missing :
59
101
ordered_row = {}
60
- for field in all_fieldnames :
102
+ for field in ordered_fieldnames :
61
103
ordered_row [field ] = row .get (field , '' )
62
104
updated_rows .append (ordered_row )
63
105
64
106
# Add new cases
65
107
for key in xpu_keys - baseline_keys :
66
- xpu_time , xpu_row = xpu_data [key ]
108
+ xpu_times , xpu_row = xpu_data [key ]
67
109
new_row = {}
68
- for field in all_fieldnames :
110
+ for field in ordered_fieldnames :
69
111
new_row [field ] = xpu_row .get (field , '' )
70
- new_row ['time(us)' ] = str (xpu_time )
112
+
113
+ if 'time(us)' in xpu_times :
114
+ new_row ['time(us)' ] = str (xpu_times ['time(us)' ])
115
+ if 'E2E total time(us)' in xpu_times :
116
+ new_row ['E2E total time(us)' ] = str (xpu_times ['E2E total time(us)' ])
117
+
71
118
updated_rows .append (new_row )
72
- added_cases .append ((key , xpu_time , new_row ))
119
+ added_cases .append ((key , xpu_times , new_row ))
73
120
74
121
# Resolve removed cases
75
122
if remove_missing :
76
123
for key in baseline_keys - xpu_keys :
77
124
removed_case = next (row for row in baseline_rows if make_key (row , fieldnames ) == key )
78
- removed_cases .append ((key , float ( removed_case [ 'time(us)' ]), removed_case ))
125
+ removed_cases .append ((key , removed_case ))
79
126
80
127
if added_cases :
81
128
print (f"\n Added { len (added_cases )} new case(s):" )
82
- for key , time , row in added_cases :
129
+ for key , times , row in added_cases :
83
130
print (f"\n [New Case] { format_case (key )} " )
84
- print (f"Time: { time } us" )
131
+ if 'time(us)' in times :
132
+ print (f"Time: { times ['time(us)' ]} us" )
133
+ if 'E2E total time(us)' in times :
134
+ print (f"E2E Time: { times ['E2E total time(us)' ]} us" )
85
135
print ("Parameters:" )
86
136
for k , v in row .items ():
87
- if k not in [ 'time(us)' , 'E2E total time(us)' , 'E2E forward time(us)' ] :
137
+ if k not in time_fields :
88
138
print (f" { k } : { v } " )
89
139
print ("-" * 60 )
90
140
91
141
if updated_cases :
92
142
print (f"\n Updated { len (updated_cases )} case(s):" )
93
- for key , old_time , new_time , row in updated_cases :
143
+ for key , old_row , new_row in updated_cases :
94
144
print (f"\n [Updated] { format_case (key )} " )
95
- print (f"Time: { old_time } us → { new_time } us" )
145
+ if 'time(us)' in old_row and 'time(us)' in new_row :
146
+ old_time = safe_float_convert (old_row ['time(us)' ])
147
+ new_time = safe_float_convert (new_row ['time(us)' ])
148
+ if old_time is not None and new_time is not None and old_time != new_time :
149
+ print (f"Time: { old_time } us → { new_time } us" )
150
+
151
+ if 'E2E total time(us)' in old_row and 'E2E total time(us)' in new_row :
152
+ old_e2e = safe_float_convert (old_row ['E2E total time(us)' ])
153
+ new_e2e = safe_float_convert (new_row ['E2E total time(us)' ])
154
+ if old_e2e is not None and new_e2e is not None and old_e2e != new_e2e :
155
+ print (f"E2E Time: { old_e2e } us → { new_e2e } us" )
156
+
96
157
print ("Parameters:" )
97
- for k , v in row .items ():
98
- if k not in [ 'time(us)' , 'E2E total time(us)' , 'E2E forward time(us)' ] :
158
+ for k , v in new_row .items ():
159
+ if k not in time_fields :
99
160
print (f" { k } : { v } " )
100
161
print ("-" * 60 )
101
162
102
163
if remove_missing and removed_cases :
103
164
print (f"\n Removed { len (removed_cases )} case(s):" )
104
- for key , time , row in removed_cases :
165
+ for key , row in removed_cases :
105
166
print (f"\n [Removed] { format_case (key )} " )
106
- print (f"Time: { time } us" )
167
+ if 'time(us)' in row :
168
+ time_val = safe_float_convert (row ['time(us)' ])
169
+ if time_val is not None :
170
+ print (f"Time: { time_val } us" )
171
+ if 'E2E total time(us)' in row :
172
+ e2e_val = safe_float_convert (row ['E2E total time(us)' ])
173
+ if e2e_val is not None :
174
+ print (f"E2E Time: { e2e_val } us" )
107
175
print ("Parameters:" )
108
176
for k , v in row .items ():
109
- if k not in [ 'time(us)' , 'E2E total time(us)' , 'E2E forward time(us)' ] :
177
+ if k not in time_fields :
110
178
print (f" { k } : { v } " )
111
179
print ("-" * 60 )
112
180
@@ -117,7 +185,7 @@ def update_baseline(xpu_file, baseline_file, remove_missing=False):
117
185
Path (baseline_file ).rename (backup_file )
118
186
119
187
with open (baseline_file , 'w' , newline = '' ) as f :
120
- writer = csv .DictWriter (f , fieldnames = all_fieldnames , delimiter = ';' )
188
+ writer = csv .DictWriter (f , fieldnames = ordered_fieldnames , delimiter = ';' )
121
189
writer .writeheader ()
122
190
writer .writerows (updated_rows )
123
191
0 commit comments