@@ -254,6 +254,10 @@ def print_stat_info(
254254 # pi is a list of constants for t > 0 for each group
255255 pi = [0 , 0 ]
256256
257+ is_correct_at_t1 = [False ] * total_samples
258+ speedup_at_t1 = [None ] * total_samples
259+ fail_type_at_t1 = ["CORRECT" ] * total_samples
260+
257261 final_correct_count = 0
258262 final_correct_negative_speedup_count = 0
259263 final_correct_speedups = []
@@ -291,8 +295,8 @@ def print_stat_info(
291295 get_correctness (eager_dtypes [i ], t_key , correctness_data , i )
292296 for i in range (output_count )
293297 )
294- if not is_correct :
295- fail_type = "accuracy"
298+ if not is_correct :
299+ fail_type = "accuracy"
296300
297301 # Collect statistics
298302 if is_correct :
@@ -306,6 +310,11 @@ def print_stat_info(
306310 if fail_type == "accuracy" :
307311 acc_failure_count += 1
308312
313+ if t_key == 1 :
314+ is_correct_at_t1 [idx ] = is_correct
315+ speedup_at_t1 [idx ] = speedup
316+ fail_type_at_t1 [idx ] = fail_type if fail_type is not None else "CORRECT"
317+
309318 # S(t) calculation
310319 if fail_type is not None or speedup is None :
311320 regularized_speedup = fpdb
@@ -320,37 +329,25 @@ def print_stat_info(
320329 # ES(t) calculation: based on state change
321330 rec_speedup_fake_degrad = 0
322331 if t_key < 1 :
323- # When t < 1, ES behaves the same as S
324332 if fail_type is not None or speedup is None :
325333 rec_speedup_fake_degrad = fpdb
326- # print(f"sample: {sample.get('configuration').get('model')}, fail_type: {fail_type}, rec_speedup_fake_degrad: {rec_speedup_fake_degrad}")
327334 else :
328335 rec_speedup_fake_degrad = (
329336 speedup ** (negative_speedup_penalty + 1 )
330337 if speedup < 1
331338 else speedup
332339 )
333340 else :
334- # When t >= 1, ES starts applying stepwise logic
335- # ES curve's stepwise state, initialized as 'CORRECT'
336- es_status = ["CORRECT" ] * total_samples
337- if es_status [idx ] == "CORRECT" and fail_type is not None :
338- es_status [idx ] = fail_type
339-
340- if (
341- es_status [idx ] is not None
342- and es_status [idx ] != "CORRECT"
343- or speedup is None
344- ):
341+ if not is_correct_at_t1 [idx ] or speedup_at_t1 [idx ] is None :
342+ fail_type_frozen = fail_type_at_t1 [idx ]
345343 rec_speedup_fake_degrad = fake_perf_degrad (
346- t_key , es_status [ idx ] , fpdb
344+ t_key , fail_type_frozen , fpdb
347345 )
348- # print(f"sample: {sample.get('configuration').get('model')}, error type: {es_status[idx]}, rec_speedup_fake_degrad: {rec_speedup_fake_degrad}")
349- else : # Still in a "CORRECT" state
346+ else :
350347 rec_speedup_fake_degrad = (
351- speedup ** (negative_speedup_penalty + 1 )
352- if speedup < 1
353- else speedup
348+ speedup_at_t1 [ idx ] ** (negative_speedup_penalty + 1 )
349+ if speedup_at_t1 [ idx ] < 1
350+ else speedup_at_t1 [ idx ]
354351 )
355352 rectified_speedups_fake_degrad .append (rec_speedup_fake_degrad )
356353
@@ -399,4 +396,3 @@ def print_stat_info(
399396 print (f" - pi: { pi } " )
400397
401398 return s_scores , s_scores_fake_degrad
402- return s_scores , es_scores
0 commit comments