Skip to content

Commit 33cbc0d

Browse files
authored
[ Bug Fix ] Fix ESt plot (#317)
* Add support to test devices on Paddle * Update
1 parent c3c8599 commit 33cbc0d

File tree

1 file changed

+18
-22
lines changed

1 file changed

+18
-22
lines changed

graph_net/analysis_util.py

Lines changed: 18 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,10 @@ def print_stat_info(
254254
# pi is a list of constants for t > 0 for each group
255255
pi = [0, 0]
256256

257+
is_correct_at_t1 = [False] * total_samples
258+
speedup_at_t1 = [None] * total_samples
259+
fail_type_at_t1 = ["CORRECT"] * total_samples
260+
257261
final_correct_count = 0
258262
final_correct_negative_speedup_count = 0
259263
final_correct_speedups = []
@@ -291,8 +295,8 @@ def print_stat_info(
291295
get_correctness(eager_dtypes[i], t_key, correctness_data, i)
292296
for i in range(output_count)
293297
)
294-
if not is_correct:
295-
fail_type = "accuracy"
298+
if not is_correct:
299+
fail_type = "accuracy"
296300

297301
# Collect statistics
298302
if is_correct:
@@ -306,6 +310,11 @@ def print_stat_info(
306310
if fail_type == "accuracy":
307311
acc_failure_count += 1
308312

313+
if t_key == 1:
314+
is_correct_at_t1[idx] = is_correct
315+
speedup_at_t1[idx] = speedup
316+
fail_type_at_t1[idx] = fail_type if fail_type is not None else "CORRECT"
317+
309318
# S(t) calculation
310319
if fail_type is not None or speedup is None:
311320
regularized_speedup = fpdb
@@ -320,37 +329,25 @@ def print_stat_info(
320329
# ES(t) calculation: based on state change
321330
rec_speedup_fake_degrad = 0
322331
if t_key < 1:
323-
# When t < 1, ES behaves the same as S
324332
if fail_type is not None or speedup is None:
325333
rec_speedup_fake_degrad = fpdb
326-
# print(f"sample: {sample.get('configuration').get('model')}, fail_type: {fail_type}, rec_speedup_fake_degrad: {rec_speedup_fake_degrad}")
327334
else:
328335
rec_speedup_fake_degrad = (
329336
speedup ** (negative_speedup_penalty + 1)
330337
if speedup < 1
331338
else speedup
332339
)
333340
else:
334-
# When t >= 1, ES starts applying stepwise logic
335-
# ES curve's stepwise state, initialized as 'CORRECT'
336-
es_status = ["CORRECT"] * total_samples
337-
if es_status[idx] == "CORRECT" and fail_type is not None:
338-
es_status[idx] = fail_type
339-
340-
if (
341-
es_status[idx] is not None
342-
and es_status[idx] != "CORRECT"
343-
or speedup is None
344-
):
341+
if not is_correct_at_t1[idx] or speedup_at_t1[idx] is None:
342+
fail_type_frozen = fail_type_at_t1[idx]
345343
rec_speedup_fake_degrad = fake_perf_degrad(
346-
t_key, es_status[idx], fpdb
344+
t_key, fail_type_frozen, fpdb
347345
)
348-
# print(f"sample: {sample.get('configuration').get('model')}, error type: {es_status[idx]}, rec_speedup_fake_degrad: {rec_speedup_fake_degrad}")
349-
else: # Still in a "CORRECT" state
346+
else:
350347
rec_speedup_fake_degrad = (
351-
speedup ** (negative_speedup_penalty + 1)
352-
if speedup < 1
353-
else speedup
348+
speedup_at_t1[idx] ** (negative_speedup_penalty + 1)
349+
if speedup_at_t1[idx] < 1
350+
else speedup_at_t1[idx]
354351
)
355352
rectified_speedups_fake_degrad.append(rec_speedup_fake_degrad)
356353

@@ -399,4 +396,3 @@ def print_stat_info(
399396
print(f" - pi: {pi}")
400397

401398
return s_scores, s_scores_fake_degrad
402-
return s_scores, es_scores

0 commit comments

Comments
 (0)