Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions codeflash/code_utils/config_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ def parse_config_file(
assert isinstance(config, dict)

# default values:
path_keys = ["module-root", "tests-root"]
path_list_keys = ["ignore-paths"]
path_keys = {"module-root", "tests-root"}
path_list_keys = {"ignore-paths", }
str_keys = {"pytest-cmd": "pytest", "git-remote": "origin"}
bool_keys = {"disable-telemetry": False, "disable-imports-sorting": False}
list_str_keys = {"formatter-cmds": ["black $file"]}
Expand Down Expand Up @@ -83,7 +83,7 @@ def parse_config_file(
else: # Default to empty list
config[key] = []

assert config["test-framework"] in ["pytest", "unittest"], (
assert config["test-framework"] in {"pytest", "unittest"}, (
"In pyproject.toml, Codeflash only supports the 'test-framework' as pytest and unittest."
)
if len(config["formatter-cmds"]) > 0:
Expand Down
16 changes: 8 additions & 8 deletions codeflash/code_utils/tabulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def _pipe_segment_with_colons(align, colwidth):
"""Return a segment of a horizontal line with optional colons which
indicate column's alignment (as in `pipe` output format)."""
w = colwidth
if align in ["right", "decimal"]:
if align in {"right", "decimal"}:
return ("-" * (w - 1)) + ":"
elif align == "center":
return ":" + ("-" * (w - 2)) + ":"
Expand Down Expand Up @@ -176,7 +176,7 @@ def _isconvertible(conv, string):
def _isnumber(string):
return (
# fast path
type(string) in (float, int)
type(string) in {float, int}
# covers 'NaN', +/- 'inf', and eg. '1e2', as well as any type
# convertible to int/float.
or (
Expand All @@ -188,7 +188,7 @@ def _isnumber(string):
# just an over/underflow
or (
not (math.isinf(float(string)) or math.isnan(float(string)))
or string.lower() in ["inf", "-inf", "nan"]
or string.lower() in {"inf", "-inf", "nan"}
)
)
)
Expand All @@ -210,7 +210,7 @@ def _isint(string, inttype=int):

def _isbool(string):
return type(string) is bool or (
isinstance(string, (bytes, str)) and string in ("True", "False")
isinstance(string, (bytes, str)) and string in {"True", "False"}
)


Expand Down Expand Up @@ -570,7 +570,7 @@ def _normalize_tabular_data(tabular_data, headers, showindex="default"):
# values is a property, has .index => it's likely a pandas.DataFrame (pandas 0.11.0)
keys = list(tabular_data)
if (
showindex in ["default", "always", True]
showindex in {"default", "always", True}
and tabular_data.index.name is not None
):
if isinstance(tabular_data.index.name, list):
Expand Down Expand Up @@ -686,7 +686,7 @@ def _normalize_tabular_data(tabular_data, headers, showindex="default"):
rows = list(map(lambda r: r if _is_separating_line(r) else list(r), rows))

# add or remove an index column
showindex_is_a_str = type(showindex) in [str, bytes]
showindex_is_a_str = type(showindex) in {str, bytes}
if showindex == "never" or (not _bool(showindex) and not showindex_is_a_str):
pass

Expand Down Expand Up @@ -820,7 +820,7 @@ def tabulate(
if colglobalalign is not None: # if global alignment provided
aligns = [colglobalalign] * len(cols)
else: # default
aligns = [numalign if ct in [int, float] else stralign for ct in coltypes]
aligns = [numalign if ct in {int, float} else stralign for ct in coltypes]
# then specific alignments
if colalign is not None:
assert isinstance(colalign, Iterable)
Expand Down Expand Up @@ -1044,4 +1044,4 @@ def _format_table(
output = "\n".join(lines)
return output
else: # a completely empty table
return ""
return ""
8 changes: 4 additions & 4 deletions codeflash/code_utils/time_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ def humanize_runtime(time_in_ns: int) -> str:

units = re.split(r",|\s", runtime_human)[1]

if units in ("microseconds", "microsecond"):
if units in {"microseconds", "microsecond"}:
runtime_human = f"{time_micro:.3g}"
elif units in ("milliseconds", "millisecond"):
elif units in {"milliseconds", "millisecond"}:
runtime_human = "%.3g" % (time_micro / 1000)
elif units in ("seconds", "second"):
elif units in {"seconds", "second"}:
runtime_human = "%.3g" % (time_micro / (1000**2))
elif units in ("minutes", "minute"):
elif units in {"minutes", "minute"}:
runtime_human = "%.3g" % (time_micro / (60 * 1000**2))
else: # hours
runtime_human = "%.3g" % (time_micro / (3600 * 1000**2))
Expand Down
6 changes: 3 additions & 3 deletions codeflash/optimization/function_optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -793,7 +793,7 @@ def establish_original_code_baseline(
line_profile_results = {"timings": {}, "unit": 0, "str_out": ""}
# For the original function - run the tests and get the runtime, plus coverage
with progress_bar(f"Establishing original code baseline for {self.function_to_optimize.function_name}"):
assert (test_framework := self.args.test_framework) in ["pytest", "unittest"]
assert (test_framework := self.args.test_framework) in {"pytest", "unittest"}
success = True

test_env = os.environ.copy()
Expand Down Expand Up @@ -941,7 +941,7 @@ def run_optimized_candidate(
original_helper_code: dict[Path, str],
file_path_to_helper_classes: dict[Path, set[str]],
) -> Result[OptimizedCandidateResult, str]:
assert (test_framework := self.args.test_framework) in ["pytest", "unittest"]
assert (test_framework := self.args.test_framework) in {"pytest", "unittest"}

with progress_bar("Testing optimization candidate"):
test_env = os.environ.copy()
Expand Down Expand Up @@ -1118,7 +1118,7 @@ def run_and_parse_tests(
f"stdout: {run_result.stdout}\n"
f"stderr: {run_result.stderr}\n"
)
if testing_type in [TestingMode.BEHAVIOR, TestingMode.PERFORMANCE]:
if testing_type in {TestingMode.BEHAVIOR, TestingMode.PERFORMANCE}:
results, coverage_results = parse_test_results(
test_xml_path=result_file_path,
test_files=test_files,
Expand Down
8 changes: 4 additions & 4 deletions codeflash/tracing/profile_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
class ProfileStats(pstats.Stats):
def __init__(self, trace_file_path: str, time_unit: str = "ns") -> None:
assert Path(trace_file_path).is_file(), f"Trace file {trace_file_path} does not exist"
assert time_unit in ["ns", "us", "ms", "s"], f"Invalid time unit {time_unit}"
assert time_unit in {"ns", "us", "ms", "s"}, f"Invalid time unit {time_unit}"
self.trace_file_path = trace_file_path
self.time_unit = time_unit
logger.debug(hasattr(self, "create_stats"))
Expand Down Expand Up @@ -59,10 +59,10 @@ def print_stats(self, *amount):
time_unit = {"ns": "nanoseconds", "us": "microseconds", "ms": "milliseconds", "s": "seconds"}[self.time_unit]
print(f"in {self.total_tt:.3f} {time_unit}", file=self.stream)
print(file=self.stream)
width, list = self.get_print_list(amount)
if list:
width, list_ = self.get_print_list(amount)
if list_:
self.print_title()
for func in list:
for func in list_:
self.print_line(func)
print(file=self.stream)
print(file=self.stream)
Expand Down
2 changes: 1 addition & 1 deletion codeflash/tracing/replay_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def get_function_alias(module: str, function_name: str) -> str:
def create_trace_replay_test(
trace_file: str, functions: list[FunctionModules], test_framework: str = "pytest", max_run_count=100
) -> str:
assert test_framework in ["pytest", "unittest"]
assert test_framework in {"pytest", "unittest"}

imports = f"""import dill as pickle
{"import unittest" if test_framework == "unittest" else ""}
Expand Down
2 changes: 1 addition & 1 deletion codeflash/verification/comparator.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def comparator(orig: Any, new: Any, superset_obj=False) -> bool:
new_keys = {k: v for k, v in new.__dict__.items() if k != "parent"}
return comparator(orig_keys, new_keys, superset_obj)

if type(orig) in [types.BuiltinFunctionType, types.BuiltinMethodType]:
if type(orig) in {types.BuiltinFunctionType, types.BuiltinMethodType}:
return new == orig
if str(type(orig)) == "<class 'object'>":
return True
Expand Down
4 changes: 2 additions & 2 deletions codeflash/verification/equivalence.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
superset_obj = False
if original_test_result.verification_type and (
original_test_result.verification_type
in (VerificationType.INIT_STATE_HELPER, VerificationType.INIT_STATE_FTO)
in {VerificationType.INIT_STATE_HELPER, VerificationType.INIT_STATE_FTO}
):
superset_obj = True
if not comparator(original_test_result.return_value, cdd_test_result.return_value, superset_obj=superset_obj):
Expand All @@ -67,7 +67,7 @@ def compare_test_results(original_results: TestResults, candidate_results: TestR
are_equal = False
break

if original_test_result.test_type in [TestType.EXISTING_UNIT_TEST, TestType.CONCOLIC_COVERAGE_TEST] and (
if original_test_result.test_type in {TestType.EXISTING_UNIT_TEST, TestType.CONCOLIC_COVERAGE_TEST} and (
cdd_test_result.did_pass != original_test_result.did_pass
):
are_equal = False
Expand Down
18 changes: 9 additions & 9 deletions codeflash/verification/parse_line_profile_test_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ def show_func(filename, start_lineno, func_name, timings, unit):
return ''
scalar = 1
if os.path.exists(filename):
out_table+=f'## Function: {func_name}\n'
out_table += f'## Function: {func_name}\n'
# Clear the cache to ensure that we get up-to-date results.
linecache.clearcache()
all_lines = linecache.getlines(filename)
sublines = inspect.getblock(all_lines[start_lineno - 1:])
out_table+='## Total time: %g s\n' % (total_time * unit)
out_table += '## Total time: %g s\n' % (total_time * unit)
# Define minimum column sizes so text fits and usually looks consistent
default_column_sizes = {
'hits': 9,
Expand Down Expand Up @@ -57,20 +57,20 @@ def show_func(filename, start_lineno, func_name, timings, unit):
if 'def' in line_ or nhits!='':
table_rows.append((nhits, time, per_hit, percent, line_))
pass
out_table+= tabulate(headers=table_cols,tabular_data=table_rows,tablefmt="pipe",colglobalalign=None, preserve_whitespace=True)
out_table += tabulate(headers=table_cols,tabular_data=table_rows,tablefmt="pipe",colglobalalign=None, preserve_whitespace=True)
out_table+='\n'
return out_table

def show_text(stats: dict) -> str:
""" Show text for the given timings.
"""
out_table = ""
out_table+='# Timer unit: %g s\n' % stats['unit']
out_table += '# Timer unit: %g s\n' % stats['unit']
stats_order = sorted(stats['timings'].items())
# Show detailed per-line information for each function.
for (fn, lineno, name), timings in stats_order:
table_md =show_func(fn, lineno, name, stats['timings'][fn, lineno, name], stats['unit'])
out_table+=table_md
table_md = show_func(fn, lineno, name, stats['timings'][fn, lineno, name], stats['unit'])
out_table += table_md
return out_table

def parse_line_profile_results(line_profiler_output_file: Optional[Path]) -> dict:
Expand All @@ -83,6 +83,6 @@ def parse_line_profile_results(line_profiler_output_file: Optional[Path]) -> dic
stats = pickle.load(f)
stats_dict['timings'] = stats.timings
stats_dict['unit'] = stats.unit
str_out=show_text(stats_dict)
stats_dict['str_out']=str_out
return stats_dict, None
str_out = show_text(stats_dict)
stats_dict['str_out'] = str_out
return stats_dict, None
2 changes: 1 addition & 1 deletion codeflash/verification/parse_test_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def parse_sqlite_test_results(sqlite_file_path: Path, test_files: TestFiles, tes
iteration_id = val[5]
runtime = val[6]
verification_type = val[8]
if verification_type in (VerificationType.INIT_STATE_FTO, VerificationType.INIT_STATE_HELPER):
if verification_type in {VerificationType.INIT_STATE_FTO, VerificationType.INIT_STATE_HELPER}:
test_type = TestType.INIT_STATE_TEST
else:
# TODO : this is because sqlite writes original file module path. Should make it consistent
Expand Down
4 changes: 2 additions & 2 deletions codeflash/verification/pytest_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,14 +128,14 @@ def pytest_runtestloop(self, session: Session) -> bool:
count += 1
total_time = self._get_total_time(session)

for index, item in enumerate(session.items):
for index, item in enumerate(session.items, 1):
item: pytest.Item = item # noqa: PLW0127, PLW2901
item._report_sections.clear() # clear reports for new test # noqa: SLF001

if total_time > SHORTEST_AMOUNT_OF_TIME:
item._nodeid = self._set_nodeid(item._nodeid, count) # noqa: SLF001

next_item: pytest.Item = session.items[index + 1] if index + 1 < len(session.items) else None
next_item: pytest.Item = session.items[index] if index < len(session.items) else None

self._clear_lru_caches(item)

Expand Down
4 changes: 2 additions & 2 deletions codeflash/verification/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def run_line_profile_tests(
)
test_files: list[str] = []
for file in test_paths.test_files:
if file.test_type in [TestType.REPLAY_TEST, TestType.EXISTING_UNIT_TEST] and file.tests_in_file:
if file.test_type in {TestType.REPLAY_TEST, TestType.EXISTING_UNIT_TEST} and file.tests_in_file:
test_files.extend(
[
str(file.benchmarking_file_path)
Expand Down Expand Up @@ -224,7 +224,7 @@ def run_benchmarking_tests(
)
test_files: list[str] = []
for file in test_paths.test_files:
if file.test_type in [TestType.REPLAY_TEST, TestType.EXISTING_UNIT_TEST] and file.tests_in_file:
if file.test_type in {TestType.REPLAY_TEST, TestType.EXISTING_UNIT_TEST} and file.tests_in_file:
test_files.extend(
[
str(file.benchmarking_file_path)
Expand Down
2 changes: 1 addition & 1 deletion codeflash/verification/verification_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


def get_test_file_path(test_dir: Path, function_name: str, iteration: int = 0, test_type: str = "unit") -> Path:
assert test_type in ["unit", "inspired", "replay", "perf"]
assert test_type in {"unit", "inspired", "replay", "perf"}
function_name = function_name.replace(".", "_")
path = test_dir / f"test_{function_name}__{test_type}_test_{iteration}.py"
if path.exists():
Expand Down
Loading