Skip to content

Commit eda75dd

Browse files
enhance the perf value update
1 parent e7a3cbf commit eda75dd

File tree

1 file changed

+107
-39
lines changed

1 file changed

+107
-39
lines changed

.github/scripts/op_calculate_best_perf.py

Lines changed: 107 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -14,22 +14,48 @@
1414
updated_cases = []
1515
removed_cases = []
1616

17+
def safe_float_convert(value):
18+
try:
19+
return float(value) if value.strip() else None
20+
except (ValueError, AttributeError):
21+
return None
22+
1723
def update_baseline(xpu_file, baseline_file, remove_missing=False):
1824
with open(xpu_file) as f:
1925
xpu_reader = csv.DictReader(f, delimiter=';')
2026
xpu_rows = list(xpu_reader)
21-
xpu_fieldnames = xpu_reader.fieldnames # Keep original field order
22-
fieldnames = [f for f in xpu_fieldnames if f not in ['time(us)', 'E2E total time(us)', 'E2E forward time(us)']]
23-
xpu_data = {make_key(row, fieldnames): (float(row['time(us)']), row) for row in xpu_rows}
27+
xpu_fieldnames = xpu_reader.fieldnames
28+
time_fields = ['time(us)', 'E2E total time(us)', 'E2E forward time(us)']
29+
fieldnames = [f for f in xpu_fieldnames if f not in time_fields]
30+
xpu_data = {}
31+
for row in xpu_rows:
32+
key = make_key(row, fieldnames)
33+
time_values = {}
34+
if 'time(us)' in row:
35+
time_val = safe_float_convert(row['time(us)'])
36+
if time_val is not None:
37+
time_values['time(us)'] = time_val
38+
if 'E2E total time(us)' in row:
39+
e2e_val = safe_float_convert(row['E2E total time(us)'])
40+
if e2e_val is not None:
41+
time_values['E2E total time(us)'] = e2e_val
42+
xpu_data[key] = (time_values, row)
2443

2544
with open(baseline_file) as f:
2645
baseline_reader = csv.DictReader(f, delimiter=';')
2746
baseline_rows = list(baseline_reader)
2847
baseline_fieldnames = baseline_reader.fieldnames
2948

3049
# To add new parameter of new ops into baseline file
31-
all_fieldnames = xpu_fieldnames + [f for f in baseline_fieldnames if f not in xpu_fieldnames]
32-
fieldnames = [f for f in all_fieldnames if f not in ['time(us)', 'E2E total time(us)', 'E2E forward time(us)']]
50+
all_fieldnames = list(set(xpu_fieldnames + baseline_fieldnames))
51+
# Maintain original order as much as possible
52+
ordered_fieldnames = []
53+
for f in xpu_fieldnames:
54+
if f in all_fieldnames and f not in ordered_fieldnames:
55+
ordered_fieldnames.append(f)
56+
for f in baseline_fieldnames:
57+
if f in all_fieldnames and f not in ordered_fieldnames:
58+
ordered_fieldnames.append(f)
3359

3460
baseline_keys = {make_key(row, fieldnames) for row in baseline_rows}
3561
xpu_keys = set(xpu_data.keys())
@@ -38,75 +64,117 @@ def update_baseline(xpu_file, baseline_file, remove_missing=False):
3864
for row in baseline_rows:
3965
key = make_key(row, fieldnames)
4066
if key in xpu_data:
41-
xpu_time, xpu_row = xpu_data[key]
42-
baseline_time = float(row['time(us)'])
43-
44-
if xpu_time < baseline_time:
45-
updated_row = {}
46-
for field in all_fieldnames:
47-
updated_row[field] = xpu_row.get(field, row.get(field, ''))
48-
updated_row['time(us)'] = str(xpu_time)
49-
if 'E2E total time(us)' in row:
50-
updated_row['E2E total time(us)'] = row['E2E total time(us)']
51-
updated_cases.append((key, baseline_time, xpu_time, updated_row))
52-
updated_rows.append(updated_row)
53-
else:
54-
ordered_row = {}
55-
for field in all_fieldnames:
56-
ordered_row[field] = row.get(field, '')
57-
updated_rows.append(ordered_row)
67+
xpu_times, xpu_row = xpu_data[key]
68+
updated_row = {}
69+
70+
# Copy all fields from baseline first
71+
for field in ordered_fieldnames:
72+
updated_row[field] = row.get(field, '')
73+
74+
# Update with xpu values where they exist
75+
for field in ordered_fieldnames:
76+
if field in xpu_row and xpu_row[field]:
77+
updated_row[field] = xpu_row[field]
78+
79+
# Handle time fields
80+
updated = False
81+
if 'time(us)' in xpu_times and 'time(us)' in row:
82+
baseline_time = safe_float_convert(row['time(us)'])
83+
if baseline_time is not None:
84+
xpu_time = xpu_times['time(us)']
85+
if xpu_time < baseline_time:
86+
updated_row['time(us)'] = str(xpu_time)
87+
updated = True
88+
89+
if 'E2E total time(us)' in xpu_times and 'E2E total time(us)' in row:
90+
baseline_e2e = safe_float_convert(row['E2E total time(us)'])
91+
if baseline_e2e is not None:
92+
xpu_e2e = xpu_times['E2E total time(us)']
93+
if xpu_e2e < baseline_e2e:
94+
updated_row['E2E total time(us)'] = str(xpu_e2e)
95+
updated = True
96+
97+
if updated:
98+
updated_cases.append((key, row, updated_row))
99+
updated_rows.append(updated_row)
58100
elif not remove_missing:
59101
ordered_row = {}
60-
for field in all_fieldnames:
102+
for field in ordered_fieldnames:
61103
ordered_row[field] = row.get(field, '')
62104
updated_rows.append(ordered_row)
63105

64106
# Add new cases
65107
for key in xpu_keys - baseline_keys:
66-
xpu_time, xpu_row = xpu_data[key]
108+
xpu_times, xpu_row = xpu_data[key]
67109
new_row = {}
68-
for field in all_fieldnames:
110+
for field in ordered_fieldnames:
69111
new_row[field] = xpu_row.get(field, '')
70-
new_row['time(us)'] = str(xpu_time)
112+
113+
if 'time(us)' in xpu_times:
114+
new_row['time(us)'] = str(xpu_times['time(us)'])
115+
if 'E2E total time(us)' in xpu_times:
116+
new_row['E2E total time(us)'] = str(xpu_times['E2E total time(us)'])
117+
71118
updated_rows.append(new_row)
72-
added_cases.append((key, xpu_time, new_row))
119+
added_cases.append((key, xpu_times, new_row))
73120

74121
# Resolve removed cases
75122
if remove_missing:
76123
for key in baseline_keys - xpu_keys:
77124
removed_case = next(row for row in baseline_rows if make_key(row, fieldnames) == key)
78-
removed_cases.append((key, float(removed_case['time(us)']), removed_case))
125+
removed_cases.append((key, removed_case))
79126

80127
if added_cases:
81128
print(f"\nAdded {len(added_cases)} new case(s):")
82-
for key, time, row in added_cases:
129+
for key, times, row in added_cases:
83130
print(f"\n[New Case] {format_case(key)}")
84-
print(f"Time: {time} us")
131+
if 'time(us)' in times:
132+
print(f"Time: {times['time(us)']} us")
133+
if 'E2E total time(us)' in times:
134+
print(f"E2E Time: {times['E2E total time(us)']} us")
85135
print("Parameters:")
86136
for k, v in row.items():
87-
if k not in ['time(us)', 'E2E total time(us)', 'E2E forward time(us)']:
137+
if k not in time_fields:
88138
print(f" {k}: {v}")
89139
print("-" * 60)
90140

91141
if updated_cases:
92142
print(f"\nUpdated {len(updated_cases)} case(s):")
93-
for key, old_time, new_time, row in updated_cases:
143+
for key, old_row, new_row in updated_cases:
94144
print(f"\n[Updated] {format_case(key)}")
95-
print(f"Time: {old_time} us → {new_time} us")
145+
if 'time(us)' in old_row and 'time(us)' in new_row:
146+
old_time = safe_float_convert(old_row['time(us)'])
147+
new_time = safe_float_convert(new_row['time(us)'])
148+
if old_time is not None and new_time is not None and old_time != new_time:
149+
print(f"Time: {old_time} us → {new_time} us")
150+
151+
if 'E2E total time(us)' in old_row and 'E2E total time(us)' in new_row:
152+
old_e2e = safe_float_convert(old_row['E2E total time(us)'])
153+
new_e2e = safe_float_convert(new_row['E2E total time(us)'])
154+
if old_e2e is not None and new_e2e is not None and old_e2e != new_e2e:
155+
print(f"E2E Time: {old_e2e} us → {new_e2e} us")
156+
96157
print("Parameters:")
97-
for k, v in row.items():
98-
if k not in ['time(us)', 'E2E total time(us)', 'E2E forward time(us)']:
158+
for k, v in new_row.items():
159+
if k not in time_fields:
99160
print(f" {k}: {v}")
100161
print("-" * 60)
101162

102163
if remove_missing and removed_cases:
103164
print(f"\nRemoved {len(removed_cases)} case(s):")
104-
for key, time, row in removed_cases:
165+
for key, row in removed_cases:
105166
print(f"\n[Removed] {format_case(key)}")
106-
print(f"Time: {time} us")
167+
if 'time(us)' in row:
168+
time_val = safe_float_convert(row['time(us)'])
169+
if time_val is not None:
170+
print(f"Time: {time_val} us")
171+
if 'E2E total time(us)' in row:
172+
e2e_val = safe_float_convert(row['E2E total time(us)'])
173+
if e2e_val is not None:
174+
print(f"E2E Time: {e2e_val} us")
107175
print("Parameters:")
108176
for k, v in row.items():
109-
if k not in ['time(us)', 'E2E total time(us)', 'E2E forward time(us)']:
177+
if k not in time_fields:
110178
print(f" {k}: {v}")
111179
print("-" * 60)
112180

@@ -117,7 +185,7 @@ def update_baseline(xpu_file, baseline_file, remove_missing=False):
117185
Path(baseline_file).rename(backup_file)
118186

119187
with open(baseline_file, 'w', newline='') as f:
120-
writer = csv.DictWriter(f, fieldnames=all_fieldnames, delimiter=';')
188+
writer = csv.DictWriter(f, fieldnames=ordered_fieldnames, delimiter=';')
121189
writer.writeheader()
122190
writer.writerows(updated_rows)
123191

0 commit comments

Comments
 (0)