Skip to content

Commit 97b0b1d

Browse files
committed
replace list with set for __include__, add space for '='
1 parent abde333 commit 97b0b1d

File tree

13 files changed

+41
-41
lines changed

13 files changed

+41
-41
lines changed

codeflash/code_utils/config_parser.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ def parse_config_file(
5252
assert isinstance(config, dict)
5353

5454
# default values:
55-
path_keys = ["module-root", "tests-root"]
56-
path_list_keys = ["ignore-paths"]
55+
path_keys = {"module-root", "tests-root"}
56+
path_list_keys = {"ignore-paths", }
5757
str_keys = {"pytest-cmd": "pytest", "git-remote": "origin"}
5858
bool_keys = {"disable-telemetry": False, "disable-imports-sorting": False}
5959
list_str_keys = {"formatter-cmds": ["black $file"]}
@@ -83,7 +83,7 @@ def parse_config_file(
8383
else: # Default to empty list
8484
config[key] = []
8585

86-
assert config["test-framework"] in ["pytest", "unittest"], (
86+
assert config["test-framework"] in {"pytest", "unittest"}, (
8787
"In pyproject.toml, Codeflash only supports the 'test-framework' as pytest and unittest."
8888
)
8989
if len(config["formatter-cmds"]) > 0:

codeflash/code_utils/tabulate.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def _pipe_segment_with_colons(align, colwidth):
7070
"""Return a segment of a horizontal line with optional colons which
7171
indicate column's alignment (as in `pipe` output format)."""
7272
w = colwidth
73-
if align in ["right", "decimal"]:
73+
if align in {"right", "decimal"}:
7474
return ("-" * (w - 1)) + ":"
7575
elif align == "center":
7676
return ":" + ("-" * (w - 2)) + ":"
@@ -176,7 +176,7 @@ def _isconvertible(conv, string):
176176
def _isnumber(string):
177177
return (
178178
# fast path
179-
type(string) in (float, int)
179+
type(string) in {float, int}
180180
# covers 'NaN', +/- 'inf', and eg. '1e2', as well as any type
181181
# convertible to int/float.
182182
or (
@@ -188,7 +188,7 @@ def _isnumber(string):
188188
# just an over/underflow
189189
or (
190190
not (math.isinf(float(string)) or math.isnan(float(string)))
191-
or string.lower() in ["inf", "-inf", "nan"]
191+
or string.lower() in {"inf", "-inf", "nan"}
192192
)
193193
)
194194
)
@@ -210,7 +210,7 @@ def _isint(string, inttype=int):
210210

211211
def _isbool(string):
212212
return type(string) is bool or (
213-
isinstance(string, (bytes, str)) and string in ("True", "False")
213+
isinstance(string, (bytes, str)) and string in {"True", "False"}
214214
)
215215

216216

@@ -570,7 +570,7 @@ def _normalize_tabular_data(tabular_data, headers, showindex="default"):
570570
# values is a property, has .index => it's likely a pandas.DataFrame (pandas 0.11.0)
571571
keys = list(tabular_data)
572572
if (
573-
showindex in ["default", "always", True]
573+
showindex in {"default", "always", True}
574574
and tabular_data.index.name is not None
575575
):
576576
if isinstance(tabular_data.index.name, list):
@@ -686,7 +686,7 @@ def _normalize_tabular_data(tabular_data, headers, showindex="default"):
686686
rows = list(map(lambda r: r if _is_separating_line(r) else list(r), rows))
687687

688688
# add or remove an index column
689-
showindex_is_a_str = type(showindex) in [str, bytes]
689+
showindex_is_a_str = type(showindex) in {str, bytes}
690690
if showindex == "never" or (not _bool(showindex) and not showindex_is_a_str):
691691
pass
692692

@@ -820,7 +820,7 @@ def tabulate(
820820
if colglobalalign is not None: # if global alignment provided
821821
aligns = [colglobalalign] * len(cols)
822822
else: # default
823-
aligns = [numalign if ct in [int, float] else stralign for ct in coltypes]
823+
aligns = [numalign if ct in {int, float} else stralign for ct in coltypes]
824824
# then specific alignments
825825
if colalign is not None:
826826
assert isinstance(colalign, Iterable)
@@ -1044,4 +1044,4 @@ def _format_table(
10441044
output = "\n".join(lines)
10451045
return output
10461046
else: # a completely empty table
1047-
return ""
1047+
return ""

codeflash/code_utils/time_utils.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@ def humanize_runtime(time_in_ns: int) -> str:
1616

1717
units = re.split(r",|\s", runtime_human)[1]
1818

19-
if units in ("microseconds", "microsecond"):
19+
if units in {"microseconds", "microsecond"}:
2020
runtime_human = f"{time_micro:.3g}"
21-
elif units in ("milliseconds", "millisecond"):
21+
elif units in {"milliseconds", "millisecond"}:
2222
runtime_human = "%.3g" % (time_micro / 1000)
23-
elif units in ("seconds", "second"):
23+
elif units in {"seconds", "second"}:
2424
runtime_human = "%.3g" % (time_micro / (1000**2))
25-
elif units in ("minutes", "minute"):
25+
elif units in {"minutes", "minute"}:
2626
runtime_human = "%.3g" % (time_micro / (60 * 1000**2))
2727
else: # hours
2828
runtime_human = "%.3g" % (time_micro / (3600 * 1000**2))

codeflash/optimization/function_optimizer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -793,7 +793,7 @@ def establish_original_code_baseline(
793793
line_profile_results = {"timings": {}, "unit": 0, "str_out": ""}
794794
# For the original function - run the tests and get the runtime, plus coverage
795795
with progress_bar(f"Establishing original code baseline for {self.function_to_optimize.function_name}"):
796-
assert (test_framework := self.args.test_framework) in ["pytest", "unittest"]
796+
assert (test_framework := self.args.test_framework) in {"pytest", "unittest"}
797797
success = True
798798

799799
test_env = os.environ.copy()
@@ -941,7 +941,7 @@ def run_optimized_candidate(
941941
original_helper_code: dict[Path, str],
942942
file_path_to_helper_classes: dict[Path, set[str]],
943943
) -> Result[OptimizedCandidateResult, str]:
944-
assert (test_framework := self.args.test_framework) in ["pytest", "unittest"]
944+
assert (test_framework := self.args.test_framework) in {"pytest", "unittest"}
945945

946946
with progress_bar("Testing optimization candidate"):
947947
test_env = os.environ.copy()
@@ -1118,7 +1118,7 @@ def run_and_parse_tests(
11181118
f"stdout: {run_result.stdout}\n"
11191119
f"stderr: {run_result.stderr}\n"
11201120
)
1121-
if testing_type in [TestingMode.BEHAVIOR, TestingMode.PERFORMANCE]:
1121+
if testing_type in {TestingMode.BEHAVIOR, TestingMode.PERFORMANCE}:
11221122
results, coverage_results = parse_test_results(
11231123
test_xml_path=result_file_path,
11241124
test_files=test_files,

codeflash/tracing/profile_stats.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
class ProfileStats(pstats.Stats):
1111
def __init__(self, trace_file_path: str, time_unit: str = "ns") -> None:
1212
assert Path(trace_file_path).is_file(), f"Trace file {trace_file_path} does not exist"
13-
assert time_unit in ["ns", "us", "ms", "s"], f"Invalid time unit {time_unit}"
13+
assert time_unit in {"ns", "us", "ms", "s"}, f"Invalid time unit {time_unit}"
1414
self.trace_file_path = trace_file_path
1515
self.time_unit = time_unit
1616
logger.debug(hasattr(self, "create_stats"))
@@ -59,10 +59,10 @@ def print_stats(self, *amount):
5959
time_unit = {"ns": "nanoseconds", "us": "microseconds", "ms": "milliseconds", "s": "seconds"}[self.time_unit]
6060
print(f"in {self.total_tt:.3f} {time_unit}", file=self.stream)
6161
print(file=self.stream)
62-
width, list = self.get_print_list(amount)
63-
if list:
62+
width, list_ = self.get_print_list(amount)
63+
if list_:
6464
self.print_title()
65-
for func in list:
65+
for func in list_:
6666
self.print_line(func)
6767
print(file=self.stream)
6868
print(file=self.stream)

codeflash/tracing/replay_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def get_function_alias(module: str, function_name: str) -> str:
4242
def create_trace_replay_test(
4343
trace_file: str, functions: list[FunctionModules], test_framework: str = "pytest", max_run_count=100
4444
) -> str:
45-
assert test_framework in ["pytest", "unittest"]
45+
assert test_framework in {"pytest", "unittest"}
4646

4747
imports = f"""import dill as pickle
4848
{"import unittest" if test_framework == "unittest" else ""}

codeflash/verification/comparator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
233233
new_keys = {k: v for k, v in new.__dict__.items() if k != "parent"}
234234
return comparator(orig_keys, new_keys, superset_obj)
235235

236-
if type(orig) in [types.BuiltinFunctionType, types.BuiltinMethodType]:
236+
if type(orig) in {types.BuiltinFunctionType, types.BuiltinMethodType}:
237237
return new == orig
238238
if str(type(orig)) == "<class 'object'>":
239239
return True

codeflash/verification/equivalence.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
4040
superset_obj = False
4141
if original_test_result.verification_type and (
4242
original_test_result.verification_type
43-
in (VerificationType.INIT_STATE_HELPER, VerificationType.INIT_STATE_FTO)
43+
in {VerificationType.INIT_STATE_HELPER, VerificationType.INIT_STATE_FTO}
4444
):
4545
superset_obj = True
4646
if not comparator(original_test_result.return_value, cdd_test_result.return_value, superset_obj=superset_obj):
@@ -67,7 +67,7 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
6767
are_equal = False
6868
break
6969

70-
if original_test_result.test_type in [TestType.EXISTING_UNIT_TEST, TestType.CONCOLIC_COVERAGE_TEST] and (
70+
if original_test_result.test_type in {TestType.EXISTING_UNIT_TEST, TestType.CONCOLIC_COVERAGE_TEST} and (
7171
cdd_test_result.did_pass != original_test_result.did_pass
7272
):
7373
are_equal = False

codeflash/verification/parse_line_profile_test_output.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@ def show_func(filename, start_lineno, func_name, timings, unit):
1616
return ''
1717
scalar = 1
1818
if os.path.exists(filename):
19-
out_table+=f'## Function: {func_name}\n'
19+
out_table += f'## Function: {func_name}\n'
2020
# Clear the cache to ensure that we get up-to-date results.
2121
linecache.clearcache()
2222
all_lines = linecache.getlines(filename)
2323
sublines = inspect.getblock(all_lines[start_lineno - 1:])
24-
out_table+='## Total time: %g s\n' % (total_time * unit)
24+
out_table += '## Total time: %g s\n' % (total_time * unit)
2525
# Define minimum column sizes so text fits and usually looks consistent
2626
default_column_sizes = {
2727
'hits': 9,
@@ -57,20 +57,20 @@ def show_func(filename, start_lineno, func_name, timings, unit):
5757
if 'def' in line_ or nhits!='':
5858
table_rows.append((nhits, time, per_hit, percent, line_))
5959
pass
60-
out_table+= tabulate(headers=table_cols,tabular_data=table_rows,tablefmt="pipe",colglobalalign=None, preserve_whitespace=True)
60+
out_table += tabulate(headers=table_cols,tabular_data=table_rows,tablefmt="pipe",colglobalalign=None, preserve_whitespace=True)
6161
out_table+='\n'
6262
return out_table
6363

6464
def show_text(stats: dict) -> str:
6565
""" Show text for the given timings.
6666
"""
6767
out_table = ""
68-
out_table+='# Timer unit: %g s\n' % stats['unit']
68+
out_table += '# Timer unit: %g s\n' % stats['unit']
6969
stats_order = sorted(stats['timings'].items())
7070
# Show detailed per-line information for each function.
7171
for (fn, lineno, name), timings in stats_order:
72-
table_md =show_func(fn, lineno, name, stats['timings'][fn, lineno, name], stats['unit'])
73-
out_table+=table_md
72+
table_md = show_func(fn, lineno, name, stats['timings'][fn, lineno, name], stats['unit'])
73+
out_table += table_md
7474
return out_table
7575

7676
def parse_line_profile_results(line_profiler_output_file: Optional[Path]) -> dict:
@@ -83,6 +83,6 @@ def parse_line_profile_results(line_profiler_output_file: Optional[Path]) -> dic
8383
stats = pickle.load(f)
8484
stats_dict['timings'] = stats.timings
8585
stats_dict['unit'] = stats.unit
86-
str_out=show_text(stats_dict)
87-
stats_dict['str_out']=str_out
88-
return stats_dict, None
86+
str_out = show_text(stats_dict)
87+
stats_dict['str_out'] = str_out
88+
return stats_dict, None

codeflash/verification/parse_test_output.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes
127127
iteration_id = val[5]
128128
runtime = val[6]
129129
verification_type = val[8]
130-
if verification_type in (VerificationType.INIT_STATE_FTO, VerificationType.INIT_STATE_HELPER):
130+
if verification_type in {VerificationType.INIT_STATE_FTO, VerificationType.INIT_STATE_HELPER}:
131131
test_type = TestType.INIT_STATE_TEST
132132
else:
133133
# TODO : this is because sqlite writes original file module path. Should make it consistent

0 commit comments

Comments
 (0)