Skip to content

Commit 498f60d

Browse files
committed
style: apply black formatting to samples_statistics.py and verify_aggregated_params.py
1 parent 6033716 commit 498f60d

File tree

2 files changed

+21
-8
lines changed

2 files changed

+21
-8
lines changed

graph_net/samples_statistics.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -264,8 +264,10 @@ def calculate_all_aggregated_parameters(
264264
# Convert dictionary-based pi and thresholds to indexed format for calculate_gamma
265265
# Create ordered list of error types for consistent indexing
266266
error_types = sorted(error_type_counts.keys())
267-
errno_tolerances = [error_tolerance_thresholds.get(error_type, 3) for error_type in error_types]
268-
267+
errno_tolerances = [
268+
error_tolerance_thresholds.get(error_type, 3) for error_type in error_types
269+
]
270+
269271
# Create get_pi function that maps error type index to pi value
270272
def get_pi(error_type_index: int) -> float:
271273
if error_type_index < len(error_types):
@@ -279,7 +281,9 @@ def get_pi(error_type_index: int) -> float:
279281
eta = calculate_eta(correct_speedups)
280282
gamma = calculate_gamma(t_key, get_pi, errno_tolerances, fpdb)
281283

282-
s_t = calculate_s_t_from_aggregated(alpha, beta, lambda_, eta, negative_speedup_penalty, fpdb)
284+
s_t = calculate_s_t_from_aggregated(
285+
alpha, beta, lambda_, eta, negative_speedup_penalty, fpdb
286+
)
283287
es_t = calculate_es_t_from_aggregated(
284288
alpha, beta, lambda_, eta, gamma, negative_speedup_penalty
285289
)
@@ -294,4 +298,3 @@ def get_pi(error_type_index: int) -> float:
294298
"s_t": s_t,
295299
"es_t": es_t,
296300
}
297-

graph_net/verify_aggregated_params.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,9 +62,15 @@ def calculate_aggregated_parameters(
6262
]
6363

6464
# Filter correct samples and extract speedups
65-
correct_samples = [(idx, speedup) for idx, _, speedup, is_correct, _ in sample_data if is_correct]
65+
correct_samples = [
66+
(idx, speedup)
67+
for idx, _, speedup, is_correct, _ in sample_data
68+
if is_correct
69+
]
6670
correct_count = len(correct_samples)
67-
correct_speedups = [speedup for _, speedup in correct_samples if speedup is not None]
71+
correct_speedups = [
72+
speedup for _, speedup in correct_samples if speedup is not None
73+
]
6874
slowdown_speedups = [speedup for speedup in correct_speedups if speedup < 1]
6975
correct_negative_speedup_count = len(slowdown_speedups)
7076

@@ -80,7 +86,12 @@ def calculate_aggregated_parameters(
8086
# Store state at t=1 using list comprehension
8187
if t_key == 1:
8288
t1_data = [
83-
(idx, speedup, is_correct, fail_type if fail_type is not None else "CORRECT")
89+
(
90+
idx,
91+
speedup,
92+
is_correct,
93+
fail_type if fail_type is not None else "CORRECT",
94+
)
8495
for idx, _, speedup, is_correct, fail_type in sample_data
8596
]
8697
is_correct_at_t1 = [is_correct for _, _, is_correct, _ in t1_data]
@@ -250,4 +261,3 @@ def main():
250261

251262
if __name__ == "__main__":
252263
main()
253-

0 commit comments

Comments
 (0)