diff --git a/code_to_optimize/bubble_sort_classmethod.py b/code_to_optimize/bubble_sort_classmethod.py new file mode 100644 index 000000000..c1cac98b7 --- /dev/null +++ b/code_to_optimize/bubble_sort_classmethod.py @@ -0,0 +1,6 @@ +from code_to_optimize.bubble_sort_in_class import BubbleSortClass + + +def sort_classmethod(x): + y = BubbleSortClass() + return y.sorter(x) diff --git a/code_to_optimize/bubble_sort_nested_classmethod.py b/code_to_optimize/bubble_sort_nested_classmethod.py new file mode 100644 index 000000000..19bef4ad2 --- /dev/null +++ b/code_to_optimize/bubble_sort_nested_classmethod.py @@ -0,0 +1,6 @@ +from code_to_optimize.bubble_sort_in_nested_class import WrapperClass + + +def sort_classmethod(x): + y = WrapperClass.BubbleSortClass() + return y.sorter(x) diff --git a/codeflash/api/aiservice.py b/codeflash/api/aiservice.py index f1c6bcf9d..fddc5c18a 100644 --- a/codeflash/api/aiservice.py +++ b/codeflash/api/aiservice.py @@ -135,6 +135,76 @@ def optimize_python_code( console.rule() return [] + def optimize_python_code_line_profiler( + self, + source_code: str, + dependency_code: str, + trace_id: str, + line_profiler_results: str, + num_candidates: int = 10, + experiment_metadata: ExperimentMetadata | None = None, + ) -> list[OptimizedCandidate]: + """Optimize the given Python code for performance by making a request to the Django endpoint. + + Parameters + ---------- + - source_code (str): The Python code to optimize. + - dependency_code (str): The dependency code used as read-only context for the optimization. + - trace_id (str): Trace id of the optimization run. + - line_profiler_results (str): Line profiler output for the function being optimized. + - num_candidates (int): Number of optimization variants to generate. Default is 10. + - experiment_metadata (ExperimentMetadata | None): Any available experiment metadata for this optimization. + + Returns + ------- + - list[OptimizedCandidate]: A list of optimization candidates.
+ + """ + payload = { + "source_code": source_code, + "dependency_code": dependency_code, + "num_variants": num_candidates, + "line_profiler_results": line_profiler_results, + "trace_id": trace_id, + "python_version": platform.python_version(), + "experiment_metadata": experiment_metadata, + "codeflash_version": codeflash_version, + } + + logger.info("Generating optimized candidates…") + console.rule() + if line_profiler_results=="": + logger.info("No LineProfiler results were provided, Skipping optimization.") + console.rule() + return [] + try: + response = self.make_ai_service_request("/optimize-line-profiler", payload=payload, timeout=600) + except requests.exceptions.RequestException as e: + logger.exception(f"Error generating optimized candidates: {e}") + ph("cli-optimize-error-caught", {"error": str(e)}) + return [] + + if response.status_code == 200: + optimizations_json = response.json()["optimizations"] + logger.info(f"Generated {len(optimizations_json)} candidates.") + console.rule() + return [ + OptimizedCandidate( + source_code=opt["source_code"], + explanation=opt["explanation"], + optimization_id=opt["optimization_id"], + ) + for opt in optimizations_json + ] + try: + error = response.json()["error"] + except Exception: + error = response.text + logger.error(f"Error generating optimized candidates: {response.status_code} - {error}") + ph("cli-optimize-error-response", {"response_status_code": response.status_code, "error": error}) + console.rule() + return [] + + def log_results( self, function_trace_id: str, diff --git a/codeflash/code_utils/line_profile_utils.py b/codeflash/code_utils/line_profile_utils.py new file mode 100644 index 000000000..21768cf68 --- /dev/null +++ b/codeflash/code_utils/line_profile_utils.py @@ -0,0 +1,223 @@ +"""Adapted from line_profiler (https://github.com/pyutils/line_profiler) written by Enthought, Inc. (BSD License)""" +from collections import defaultdict +from pathlib import Path +from typing import Union + +import isort +import libcst as cst + +from codeflash.code_utils.code_utils import get_run_tmp_file + + +class LineProfilerDecoratorAdder(cst.CSTTransformer): + """Transformer that adds a decorator to a function with a specific qualified name.""" + + #TODO we don't support nested functions yet so they can only be inside classes, dont use qualified names, instead use the structure + def __init__(self, qualified_name: str, decorator_name: str): + """Initialize the transformer. + + Args: + qualified_name: The fully qualified name of the function to add the decorator to (e.g., "MyClass.nested_func.target_func"). + decorator_name: The name of the decorator to add. 
+ + """ + super().__init__() + self.qualified_name_parts = qualified_name.split(".") + self.decorator_name = decorator_name + + # Track our current context path, only add when we encounter a class + self.context_stack = [] + + def visit_ClassDef(self, node: cst.ClassDef) -> None: + # Track when we enter a class + self.context_stack.append(node.name.value) + + def leave_ClassDef(self, original_node: cst.ClassDef, updated_node: cst.ClassDef) -> cst.ClassDef: + # Pop the context when we leave a class + self.context_stack.pop() + return updated_node + + def visit_FunctionDef(self, node: cst.FunctionDef) -> None: + # Track when we enter a function + self.context_stack.append(node.name.value) + + def leave_FunctionDef(self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef) -> cst.FunctionDef: + function_name = original_node.name.value + + # Check if the current context path matches our target qualified name + if self.context_stack==self.qualified_name_parts: + # Check if the decorator is already present + has_decorator = any( + self._is_target_decorator(decorator.decorator) + for decorator in original_node.decorators + ) + + # Only add the decorator if it's not already there + if not has_decorator: + new_decorator = cst.Decorator( + decorator=cst.Name(value=self.decorator_name) + ) + + # Add our new decorator to the existing decorators + updated_decorators = [new_decorator] + list(updated_node.decorators) + updated_node = updated_node.with_changes( + decorators=tuple(updated_decorators) + ) + + # Pop the context when we leave a function + self.context_stack.pop() + return updated_node + + def _is_target_decorator(self, decorator_node: Union[cst.Name, cst.Attribute, cst.Call]) -> bool: + """Check if a decorator matches our target decorator name.""" + if isinstance(decorator_node, cst.Name): + return decorator_node.value == self.decorator_name + if isinstance(decorator_node, cst.Call) and isinstance(decorator_node.func, cst.Name): + return decorator_node.func.value == self.decorator_name + return False + +class ProfileEnableTransformer(cst.CSTTransformer): + def __init__(self,filename): + # Flag to track if we found the import statement + self.found_import = False + # Track indentation of the import statement + self.import_indentation = None + self.filename = filename + + def leave_ImportFrom(self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom) -> cst.ImportFrom: + # Check if this is the line profiler import statement + if (isinstance(original_node.module, cst.Name) and + original_node.module.value == "line_profiler" and + any(name.name.value == "profile" and + (not name.asname or name.asname.name.value == "codeflash_line_profile") + for name in original_node.names)): + + self.found_import = True + # Get the indentation from the original node + if hasattr(original_node, "leading_lines"): + leading_whitespace = original_node.leading_lines[-1].whitespace if original_node.leading_lines else "" + self.import_indentation = leading_whitespace + + return updated_node + + def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> cst.Module: + if not self.found_import: + return updated_node + + # Create a list of statements from the original module + new_body = list(updated_node.body) + + # Find the index of the import statement + import_index = None + for i, stmt in enumerate(new_body): + if isinstance(stmt, cst.SimpleStatementLine): + for small_stmt in stmt.body: + if isinstance(small_stmt, cst.ImportFrom): + if (isinstance(small_stmt.module, cst.Name) and 
small_stmt.module.value == "line_profiler" and + any(name.name.value == "profile" and + (not name.asname or name.asname.name.value == "codeflash_line_profile") + for name in small_stmt.names)): + import_index = i + break + if import_index is not None: + break + + if import_index is not None: + # Create the new enable statement to insert after the import + enable_statement = cst.parse_statement( + f"codeflash_line_profile.enable(output_prefix='{self.filename}')" + ) + + # Insert the new statement after the import statement + new_body.insert(import_index + 1, enable_statement) + + # Create a new module with the updated body + return updated_node.with_changes(body=new_body) +def add_decorator_to_qualified_function(module, qualified_name, decorator_name): + """Add a decorator to the function with the exact qualified name in the given module. + + Args: + module: The libcst Module to transform. + qualified_name: The fully qualified name of the function to add the decorator to (e.g., "MyClass.nested_func.target_func"). + decorator_name: The name of the decorator to add. + + Returns: + The modified libcst Module. + + """ + # Apply our transformer + transformer = LineProfilerDecoratorAdder(qualified_name, decorator_name) + modified_module = module.visit(transformer) + + # Return the modified CST + return modified_module + +def add_profile_enable(original_code: str, line_profile_output_file: str) -> str: + module = cst.parse_module(original_code) + transformer = ProfileEnableTransformer(line_profile_output_file) + modified_module = module.visit(transformer) + return modified_module.code + + +class ImportAdder(cst.CSTTransformer): + def __init__(self, import_statement): + self.import_statement = import_statement + self.has_import = False + + def leave_Module(self, original_node, updated_node): + # If the import is already there, don't add it again + if self.has_import: + return updated_node + + # Parse the import statement into a CST node + import_node = cst.parse_statement(self.import_statement) + + # Add the import to the module's body + return updated_node.with_changes( + body=[import_node] + list(updated_node.body) + ) + + def visit_ImportFrom(self, node): + # Check if profile is already imported from line_profiler + if node.module and node.module.value == "line_profiler": + for import_alias in node.names: + if import_alias.name.value == "profile": + self.has_import = True + + +def add_decorator_imports(function_to_optimize, code_context): + """Add a profiling decorator to the function under optimization and to all of its helper functions.""" + file_paths = defaultdict(list) + line_profile_output_file = get_run_tmp_file(Path("baseline_lprof")) + file_paths[function_to_optimize.file_path].append(function_to_optimize.qualified_name) + for elem in code_context.helper_functions: + file_paths[elem.file_path].append(elem.qualified_name) + for file_path, fns_present in file_paths.items(): + # open the file + file_contents = file_path.read_text("utf-8") + # parse it into a CST + module_node = cst.parse_module(file_contents) + for fn_name in fns_present: + # add the decorator + module_node = add_decorator_to_qualified_function(module_node, fn_name, "codeflash_line_profile") + # add imports + # Create a transformer to add
the import + transformer = ImportAdder("from line_profiler import profile as codeflash_line_profile") + # Apply the transformer to add the import + module_node = module_node.visit(transformer) + modified_code = isort.code(module_node.code, float_to_top=True) + # write the modified source back to the file + with open(file_path, "w", encoding="utf-8") as file: + file.write(modified_code) + # Add the profile.enable() call that sets the save path for the profiling data; do this only for the main file, not the helper files + file_contents = function_to_optimize.file_path.read_text("utf-8") + modified_code = add_profile_enable(file_contents, str(line_profile_output_file)) + function_to_optimize.file_path.write_text(modified_code, "utf-8") + return line_profile_output_file diff --git a/codeflash/code_utils/tabulate.py b/codeflash/code_utils/tabulate.py new file mode 100644 index 000000000..c278bfeae --- /dev/null +++ b/codeflash/code_utils/tabulate.py @@ -0,0 +1,1047 @@ +"""Pretty-print tabular data. + +Adapted from tabulate (https://github.com/astanin/python-tabulate) written by Sergey Astanin and contributors (MIT License). +""" + +import warnings +from collections import namedtuple +from collections.abc import Iterable +from itertools import chain, zip_longest as izip_longest +from functools import reduce +import re +import math +import dataclasses + +try: + import wcwidth # optional wide-character (CJK) support +except ImportError: + wcwidth = None + +__all__ = ["tabulate", "tabulate_formats"] + +# minimum extra space in headers +MIN_PADDING = 2 + +_DEFAULT_FLOATFMT = "g" +_DEFAULT_INTFMT = "" +_DEFAULT_MISSINGVAL = "" +# default align will be overwritten by "left", "center" or "decimal" +# depending on the formatter +_DEFAULT_ALIGN = "default" + + +# if True, enable wide-character (CJK) support +WIDE_CHARS_MODE = wcwidth is not None + +# Constant that can be used as part of passed rows to generate a separating line +# It is purposely an unprintable character, very unlikely to be used in a table +SEPARATING_LINE = "\001" + +Line = namedtuple("Line", ["begin", "hline", "sep", "end"]) + + +DataRow = namedtuple("DataRow", ["begin", "sep", "end"]) + +TableFormat = namedtuple( + "TableFormat", + [ + "lineabove", + "linebelowheader", + "linebetweenrows", + "linebelow", + "headerrow", + "datarow", + "padding", + "with_header_hide", + ], +) + + +def _is_separating_line_value(value): + return type(value) is str and value.strip() == SEPARATING_LINE + + +def _is_separating_line(row): + row_type = type(row) + is_sl = (row_type == list or row_type == str) and ( + (len(row) >= 1 and _is_separating_line_value(row[0])) + or (len(row) >= 2 and _is_separating_line_value(row[1])) + ) + + return is_sl + + +def _pipe_segment_with_colons(align, colwidth): + """Return a segment of a horizontal line with optional colons which + indicate column's alignment (as in `pipe` output format).""" + w = colwidth + if align in ["right", "decimal"]: + return ("-" * (w - 1)) + ":" + elif align == "center": + return ":" + ("-" * (w - 2)) + ":" + elif align == "left": + return ":" + ("-" * (w - 1)) + else: + return "-" * w + + +def _pipe_line_with_colons(colwidths, colaligns): + """Return a horizontal line with optional colons to indicate column's + alignment (as in `pipe` output format).""" + if not colaligns: # e.g.
printing an empty data frame (github issue #15) + colaligns = [""] * len(colwidths) + segments = [_pipe_segment_with_colons(a, w) for a, w in zip(colaligns, colwidths)] + return "|" + "|".join(segments) + "|" + +_table_formats = { + "simple": TableFormat( + lineabove=Line("", "-", " ", ""), + linebelowheader=Line("", "-", " ", ""), + linebetweenrows=None, + linebelow=Line("", "-", " ", ""), + headerrow=DataRow("", " ", ""), + datarow=DataRow("", " ", ""), + padding=0, + with_header_hide=["lineabove", "linebelow"], + ), + "pipe": TableFormat( + lineabove=_pipe_line_with_colons, + linebelowheader=_pipe_line_with_colons, + linebetweenrows=None, + linebelow=None, + headerrow=DataRow("|", "|", "|"), + datarow=DataRow("|", "|", "|"), + padding=1, + with_header_hide=["lineabove"], + ), +} + +tabulate_formats = list(sorted(_table_formats.keys())) + +# The table formats for which multiline cells will be folded into subsequent +# table rows. The key is the original format specified at the API. The value is +# the format that will be used to represent the original format. +multiline_formats = { + "plain": "plain", + "pipe": "pipe", + +} + +_multiline_codes = re.compile(r"\r|\n|\r\n") +_multiline_codes_bytes = re.compile(b"\r|\n|\r\n") + +_esc = r"\x1b" +_csi = rf"{_esc}\[" +_osc = rf"{_esc}\]" +_st = rf"{_esc}\\" + +_ansi_escape_pat = rf""" + ( + # terminal colors, etc + {_csi} # CSI + [\x30-\x3f]* # parameter bytes + [\x20-\x2f]* # intermediate bytes + [\x40-\x7e] # final byte + | + # terminal hyperlinks + {_osc}8; # OSC opening + (\w+=\w+:?)* # key=value params list (submatch 2) + ; # delimiter + ([^{_esc}]+) # URI - anything but ESC (submatch 3) + {_st} # ST + ([^{_esc}]+) # link text - anything but ESC (submatch 4) + {_osc}8;;{_st} # "closing" OSC sequence + ) +""" +_ansi_codes = re.compile(_ansi_escape_pat, re.VERBOSE) +_ansi_codes_bytes = re.compile(_ansi_escape_pat.encode("utf8"), re.VERBOSE) +_ansi_color_reset_code = "\033[0m" + +_float_with_thousands_separators = re.compile( + r"^(([+-]?[0-9]{1,3})(?:,([0-9]{3}))*)?(?(1)\.[0-9]*|\.[0-9]+)?$" +) + +def _isnumber_with_thousands_separator(string): + try: + string = string.decode() + except (UnicodeDecodeError, AttributeError): + pass + + return bool(re.match(_float_with_thousands_separators, string)) + + +def _isconvertible(conv, string): + try: + conv(string) + return True + except (ValueError, TypeError): + return False + + +def _isnumber(string): + return ( + # fast path + type(string) in (float, int) + # covers 'NaN', +/- 'inf', and eg. '1e2', as well as any type + # convertible to int/float. + or ( + _isconvertible(float, string) + and ( + # some other type convertible to float + not isinstance(string, (str, bytes)) + # or, a numeric string eg. 
"1e1...", "NaN", ..., but isn't + # just an over/underflow + or ( + not (math.isinf(float(string)) or math.isnan(float(string))) + or string.lower() in ["inf", "-inf", "nan"] + ) + ) + ) + ) + + +def _isint(string, inttype=int): + return ( + type(string) is inttype + or ( + (hasattr(string, "is_integer") or hasattr(string, "__array__")) + and str(type(string)).startswith("= 0: + return len(string) - pos - 1 + else: + return -1 # no point + else: + return -1 # not a number + + +def _padleft(width, s): + fmt = "{0:>%ds}" % width + return fmt.format(s) + + +def _padright(width, s): + fmt = "{0:<%ds}" % width + return fmt.format(s) + + +def _padboth(width, s): + fmt = "{0:^%ds}" % width + return fmt.format(s) + + +def _padnone(ignore_width, s): + return s + + +def _strip_ansi(s): + if isinstance(s, str): + return _ansi_codes.sub(r"\4", s) + else: # a bytestring + return _ansi_codes_bytes.sub(r"\4", s) + + +def _visible_width(s): + if wcwidth is not None and WIDE_CHARS_MODE: + len_fn = wcwidth.wcswidth + else: + len_fn = len + if isinstance(s, (str, bytes)): + return len_fn(_strip_ansi(s)) + else: + return len_fn(str(s)) + + +def _is_multiline(s): + if isinstance(s, str): + return bool(re.search(_multiline_codes, s)) + else: # a bytestring + return bool(re.search(_multiline_codes_bytes, s)) + + +def _multiline_width(multiline_s, line_width_fn=len): + return max(map(line_width_fn, re.split("[\r\n]", multiline_s))) + + +def _choose_width_fn(has_invisible, enable_widechars, is_multiline): + if has_invisible: + line_width_fn = _visible_width + elif enable_widechars: # optional wide-character support if available + line_width_fn = wcwidth.wcswidth + else: + line_width_fn = len + if is_multiline: + width_fn = lambda s: _multiline_width(s, line_width_fn) # noqa + else: + width_fn = line_width_fn + return width_fn + + +def _align_column_choose_padfn(strings, alignment, has_invisible, preserve_whitespace): + if alignment == "right": + if not preserve_whitespace: + strings = [s.strip() for s in strings] + padfn = _padleft + elif alignment == "center": + if not preserve_whitespace: + strings = [s.strip() for s in strings] + padfn = _padboth + elif alignment == "decimal": + if has_invisible: + decimals = [_afterpoint(_strip_ansi(s)) for s in strings] + else: + decimals = [_afterpoint(s) for s in strings] + maxdecimals = max(decimals) + strings = [s + (maxdecimals - decs) * " " for s, decs in zip(strings, decimals)] + padfn = _padleft + elif not alignment: + padfn = _padnone + else: + if not preserve_whitespace: + strings = [s.strip() for s in strings] + padfn = _padright + return strings, padfn + + +def _align_column_choose_width_fn(has_invisible, enable_widechars, is_multiline): + if has_invisible: + line_width_fn = _visible_width + elif enable_widechars: # optional wide-character support if available + line_width_fn = wcwidth.wcswidth + else: + line_width_fn = len + if is_multiline: + width_fn = lambda s: _align_column_multiline_width(s, line_width_fn) # noqa + else: + width_fn = line_width_fn + return width_fn + + +def _align_column_multiline_width(multiline_s, line_width_fn=len): + return list(map(line_width_fn, re.split("[\r\n]", multiline_s))) + + +def _flat_list(nested_list): + ret = [] + for item in nested_list: + if isinstance(item, list): + ret.extend(item) + else: + ret.append(item) + return ret + + +def _align_column( + strings, + alignment, + minwidth=0, + has_invisible=True, + enable_widechars=False, + is_multiline=False, + preserve_whitespace=False, +): + strings, padfn = 
_align_column_choose_padfn( + strings, alignment, has_invisible, preserve_whitespace + ) + width_fn = _align_column_choose_width_fn( + has_invisible, enable_widechars, is_multiline + ) + + s_widths = list(map(width_fn, strings)) + maxwidth = max(max(_flat_list(s_widths)), minwidth) + # TODO: refactor column alignment in single-line and multiline modes + if is_multiline: + if not enable_widechars and not has_invisible: + padded_strings = [ + "\n".join([padfn(maxwidth, s) for s in ms.splitlines()]) + for ms in strings + ] + else: + # enable wide-character width corrections + s_lens = [[len(s) for s in re.split("[\r\n]", ms)] for ms in strings] + visible_widths = [ + [maxwidth - (w - l) for w, l in zip(mw, ml)] + for mw, ml in zip(s_widths, s_lens) + ] + # wcswidth and _visible_width don't count invisible characters; + # padfn doesn't need to apply another correction + padded_strings = [ + "\n".join([padfn(w, s) for s, w in zip((ms.splitlines() or ms), mw)]) + for ms, mw in zip(strings, visible_widths) + ] + else: # single-line cell values + if not enable_widechars and not has_invisible: + padded_strings = [padfn(maxwidth, s) for s in strings] + else: + # enable wide-character width corrections + s_lens = list(map(len, strings)) + visible_widths = [maxwidth - (w - l) for w, l in zip(s_widths, s_lens)] + # wcswidth and _visible_width don't count invisible characters; + # padfn doesn't need to apply another correction + padded_strings = [padfn(w, s) for s, w in zip(strings, visible_widths)] + return padded_strings + + +def _more_generic(type1, type2): + types = { + type(None): 0, + bool: 1, + int: 2, + float: 3, + bytes: 4, + str: 5, + } + invtypes = { + 5: str, + 4: bytes, + 3: float, + 2: int, + 1: bool, + 0: type(None), + } + moregeneric = max(types.get(type1, 5), types.get(type2, 5)) + return invtypes[moregeneric] + + +def _column_type(strings, has_invisible=True, numparse=True): + types = [_type(s, has_invisible, numparse) for s in strings] + return reduce(_more_generic, types, bool) + + +def _format(val, valtype, floatfmt, intfmt, missingval="", has_invisible=True): + if val is None: + return missingval + if isinstance(val, (bytes, str)) and not val: + return "" + + if valtype is str: + return f"{val}" + elif valtype is int: + if isinstance(val, str): + val_striped = val.encode("unicode_escape").decode("utf-8") + colored = re.search( + r"(\\[xX]+[0-9a-fA-F]+\[\d+[mM]+)([0-9.]+)(\\.*)$", val_striped + ) + if colored: + total_groups = len(colored.groups()) + if total_groups == 3: + digits = colored.group(2) + if digits.isdigit(): + val_new = ( + colored.group(1) + + format(int(digits), intfmt) + + colored.group(3) + ) + val = val_new.encode("utf-8").decode("unicode_escape") + intfmt = "" + return format(val, intfmt) + elif valtype is bytes: + try: + return str(val, "ascii") + except (TypeError, UnicodeDecodeError): + return str(val) + elif valtype is float: + is_a_colored_number = has_invisible and isinstance(val, (str, bytes)) + if is_a_colored_number: + raw_val = _strip_ansi(val) + formatted_val = format(float(raw_val), floatfmt) + return val.replace(raw_val, formatted_val) + else: + if isinstance(val, str) and "," in val: + val = val.replace(",", "") # handle thousands-separators + return format(float(val), floatfmt) + else: + return f"{val}" + + +def _align_header( + header, alignment, width, visible_width, is_multiline=False, width_fn=None +): + "Pad string header to width chars given known visible_width of the header." 
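+    # ANSI escape codes count toward len(header) but not toward its display width, so the pad target below is widened by the invisible portion before padding.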
+ if is_multiline: + header_lines = re.split(_multiline_codes, header) + padded_lines = [ + _align_header(h, alignment, width, width_fn(h)) for h in header_lines + ] + return "\n".join(padded_lines) + # else: not multiline + ninvisible = len(header) - visible_width + width += ninvisible + if alignment == "left": + return _padright(width, header) + elif alignment == "center": + return _padboth(width, header) + elif not alignment: + return f"{header}" + else: + return _padleft(width, header) + + +def _remove_separating_lines(rows): + if isinstance(rows, list): + separating_lines = [] + sans_rows = [] + for index, row in enumerate(rows): + if _is_separating_line(row): + separating_lines.append(index) + else: + sans_rows.append(row) + return sans_rows, separating_lines + else: + return rows, None + +def _bool(val): + "A wrapper around standard bool() which doesn't throw on NumPy arrays" + try: + return bool(val) + except ValueError: # val is likely to be a numpy array with many elements + return False + + +def _normalize_tabular_data(tabular_data, headers, showindex="default"): + try: + bool(headers) + except ValueError: # numpy.ndarray, pandas.core.index.Index, ... + headers = list(headers) + + err_msg = ( + "\n\nTo build a table python-tabulate requires two-dimensional data " + "like a list of lists or similar." + "\nDid you forget a pair of extra [] or ',' in ()?" + ) + index = None + if hasattr(tabular_data, "keys") and hasattr(tabular_data, "values"): + # dict-like and pandas.DataFrame? + if hasattr(tabular_data.values, "__call__"): + # likely a conventional dict + keys = tabular_data.keys() + try: + rows = list( + izip_longest(*tabular_data.values()) + ) # columns have to be transposed + except TypeError: # not iterable + raise TypeError(err_msg) + + elif hasattr(tabular_data, "index"): + # values is a property, has .index => it's likely a pandas.DataFrame (pandas 0.11.0) + keys = list(tabular_data) + if ( + showindex in ["default", "always", True] + and tabular_data.index.name is not None + ): + if isinstance(tabular_data.index.name, list): + keys[:0] = tabular_data.index.name + else: + keys[:0] = [tabular_data.index.name] + vals = tabular_data.values # values matrix doesn't need to be transposed + # for DataFrames add an index per default + index = list(tabular_data.index) + rows = [list(row) for row in vals] + else: + raise ValueError("tabular data doesn't appear to be a dict or a DataFrame") + + if headers == "keys": + headers = list(map(str, keys)) # headers should be strings + + else: # it's a usual iterable of iterables, or a NumPy array, or an iterable of dataclasses + try: + rows = list(tabular_data) + except TypeError: # not iterable + raise TypeError(err_msg) + + if headers == "keys" and not rows: + # an empty table (issue #81) + headers = [] + elif ( + headers == "keys" + and hasattr(tabular_data, "dtype") + and getattr(tabular_data.dtype, "names") + ): + # numpy record array + headers = tabular_data.dtype.names + elif ( + headers == "keys" + and len(rows) > 0 + and isinstance(rows[0], tuple) + and hasattr(rows[0], "_fields") + ): + # namedtuple + headers = list(map(str, rows[0]._fields)) + elif len(rows) > 0 and hasattr(rows[0], "keys") and hasattr(rows[0], "values"): + # dict-like object + uniq_keys = set() # implements hashed lookup + keys = [] # storage for set + if headers == "firstrow": + firstdict = rows[0] if len(rows) > 0 else {} + keys.extend(firstdict.keys()) + uniq_keys.update(keys) + rows = rows[1:] + for row in rows: + for k in row.keys(): + # Save unique items 
in input order + if k not in uniq_keys: + keys.append(k) + uniq_keys.add(k) + if headers == "keys": + headers = keys + elif isinstance(headers, dict): + # a dict of headers for a list of dicts + headers = [headers.get(k, k) for k in keys] + headers = list(map(str, headers)) + elif headers == "firstrow": + if len(rows) > 0: + headers = [firstdict.get(k, k) for k in keys] + headers = list(map(str, headers)) + else: + headers = [] + elif headers: + raise ValueError( + "headers for a list of dicts is not a dict or a keyword" + ) + rows = [[row.get(k) for k in keys] for row in rows] + + elif ( + headers == "keys" + and hasattr(tabular_data, "description") + and hasattr(tabular_data, "fetchone") + and hasattr(tabular_data, "rowcount") + ): + # Python Database API cursor object (PEP 0249) + # print tabulate(cursor, headers='keys') + headers = [column[0] for column in tabular_data.description] + + elif ( + dataclasses is not None + and len(rows) > 0 + and dataclasses.is_dataclass(rows[0]) + ): + # Python's dataclass + field_names = [field.name for field in dataclasses.fields(rows[0])] + if headers == "keys": + headers = field_names + rows = [[getattr(row, f) for f in field_names] for row in rows] + + elif headers == "keys" and len(rows) > 0: + # keys are column indices + headers = list(map(str, range(len(rows[0])))) + + # take headers from the first row if necessary + if headers == "firstrow" and len(rows) > 0: + if index is not None: + headers = [index[0]] + list(rows[0]) + index = index[1:] + else: + headers = rows[0] + headers = list(map(str, headers)) # headers should be strings + rows = rows[1:] + elif headers == "firstrow": + headers = [] + + headers = list(map(str, headers)) + # rows = list(map(list, rows)) + rows = list(map(lambda r: r if _is_separating_line(r) else list(r), rows)) + + # add or remove an index column + showindex_is_a_str = type(showindex) in [str, bytes] + if showindex == "never" or (not _bool(showindex) and not showindex_is_a_str): + pass + + # pad with empty headers for initial columns if necessary + headers_pad = 0 + if headers and len(rows) > 0: + headers_pad = max(0, len(rows[0]) - len(headers)) + headers = [""] * headers_pad + headers + + return rows, headers, headers_pad + +def _to_str(s, encoding="utf8", errors="ignore"): + if isinstance(s, bytes): + return s.decode(encoding=encoding, errors=errors) + return str(s) + + +def tabulate( + tabular_data, + headers=(), + tablefmt="simple", + floatfmt=_DEFAULT_FLOATFMT, + intfmt=_DEFAULT_INTFMT, + numalign=_DEFAULT_ALIGN, + stralign=_DEFAULT_ALIGN, + missingval=_DEFAULT_MISSINGVAL, + showindex="default", + disable_numparse=False, + colglobalalign=None, + colalign=None, + preserve_whitespace=False, + maxcolwidths=None, + headersglobalalign=None, + headersalign=None, + rowalign=None, + maxheadercolwidths=None, +): + if tabular_data is None: + tabular_data = [] + + list_of_lists, headers, headers_pad = _normalize_tabular_data( + tabular_data, headers, showindex=showindex + ) + list_of_lists, separating_lines = _remove_separating_lines(list_of_lists) + + # PrettyTable formatting does not use any extra padding. + # Numbers are not parsed and are treated the same as strings for alignment. + # Check if pretty is the format being used and override the defaults so it + # does not impact other formats. 
+ min_padding = MIN_PADDING + if tablefmt == "pretty": + min_padding = 0 + disable_numparse = True + numalign = "center" if numalign == _DEFAULT_ALIGN else numalign + stralign = "center" if stralign == _DEFAULT_ALIGN else stralign + else: + numalign = "decimal" if numalign == _DEFAULT_ALIGN else numalign + stralign = "left" if stralign == _DEFAULT_ALIGN else stralign + + # 'colon_grid' uses colons in the line beneath the header to represent a column's + # alignment instead of literally aligning the text differently. Hence, + # left alignment of the data in the text output is enforced. + if tablefmt == "colon_grid": + colglobalalign = "left" + headersglobalalign = "left" + + # optimization: look for ANSI control codes once, + # enable smart width functions only if a control code is found + # + # convert the headers and rows into a single, tab-delimited string ensuring + # that any bytestrings are decoded safely (i.e. errors ignored) + plain_text = "\t".join( + chain( + # headers + map(_to_str, headers), + # rows: chain the rows together into a single iterable after mapping + # the bytestring conversino to each cell value + chain.from_iterable(map(_to_str, row) for row in list_of_lists), + ) + ) + + has_invisible = _ansi_codes.search(plain_text) is not None + + enable_widechars = wcwidth is not None and WIDE_CHARS_MODE + if ( + not isinstance(tablefmt, TableFormat) + and tablefmt in multiline_formats + and _is_multiline(plain_text) + ): + tablefmt = multiline_formats.get(tablefmt, tablefmt) + is_multiline = True + else: + is_multiline = False + width_fn = _choose_width_fn(has_invisible, enable_widechars, is_multiline) + + # format rows and columns, convert numeric values to strings + cols = list(izip_longest(*list_of_lists)) + numparses = _expand_numparse(disable_numparse, len(cols)) + coltypes = [_column_type(col, numparse=np) for col, np in zip(cols, numparses)] + if isinstance(floatfmt, str): # old version + float_formats = len(cols) * [ + floatfmt + ] # just duplicate the string to use in each column + else: # if floatfmt is list, tuple etc we have one per column + float_formats = list(floatfmt) + if len(float_formats) < len(cols): + float_formats.extend((len(cols) - len(float_formats)) * [_DEFAULT_FLOATFMT]) + if isinstance(intfmt, str): # old version + int_formats = len(cols) * [ + intfmt + ] # just duplicate the string to use in each column + else: # if intfmt is list, tuple etc we have one per column + int_formats = list(intfmt) + if len(int_formats) < len(cols): + int_formats.extend((len(cols) - len(int_formats)) * [_DEFAULT_INTFMT]) + if isinstance(missingval, str): + missing_vals = len(cols) * [missingval] + else: + missing_vals = list(missingval) + if len(missing_vals) < len(cols): + missing_vals.extend((len(cols) - len(missing_vals)) * [_DEFAULT_MISSINGVAL]) + cols = [ + [_format(v, ct, fl_fmt, int_fmt, miss_v, has_invisible) for v in c] + for c, ct, fl_fmt, int_fmt, miss_v in zip( + cols, coltypes, float_formats, int_formats, missing_vals + ) + ] + + # align columns + # first set global alignment + if colglobalalign is not None: # if global alignment provided + aligns = [colglobalalign] * len(cols) + else: # default + aligns = [numalign if ct in [int, float] else stralign for ct in coltypes] + # then specific alignments + if colalign is not None: + assert isinstance(colalign, Iterable) + if isinstance(colalign, str): + warnings.warn( + f"As a string, `colalign` is interpreted as {[c for c in colalign]}. 
" + f'Did you mean `colglobalalign = "{colalign}"` or `colalign = ("{colalign}",)`?', + stacklevel=2, + ) + for idx, align in enumerate(colalign): + if not idx < len(aligns): + break + elif align != "global": + aligns[idx] = align + minwidths = ( + [width_fn(h) + min_padding for h in headers] if headers else [0] * len(cols) + ) + aligns_copy = aligns.copy() + # Reset alignments in copy of alignments list to "left" for 'colon_grid' format, + # which enforces left alignment in the text output of the data. + if tablefmt == "colon_grid": + aligns_copy = ["left"] * len(cols) + cols = [ + _align_column( + c, + a, + minw, + has_invisible, + enable_widechars, + is_multiline, + preserve_whitespace, + ) + for c, a, minw in zip(cols, aligns_copy, minwidths) + ] + + aligns_headers = None + if headers: + # align headers and add headers + t_cols = cols or [[""]] * len(headers) + # first set global alignment + if headersglobalalign is not None: # if global alignment provided + aligns_headers = [headersglobalalign] * len(t_cols) + else: # default + aligns_headers = aligns or [stralign] * len(headers) + # then specific header alignments + if headersalign is not None: + assert isinstance(headersalign, Iterable) + if isinstance(headersalign, str): + warnings.warn( + f"As a string, `headersalign` is interpreted as {[c for c in headersalign]}. " + f'Did you mean `headersglobalalign = "{headersalign}"` ' + f'or `headersalign = ("{headersalign}",)`?', + stacklevel=2, + ) + for idx, align in enumerate(headersalign): + hidx = headers_pad + idx + if not hidx < len(aligns_headers): + break + elif align == "same" and hidx < len(aligns): # same as column align + aligns_headers[hidx] = aligns[hidx] + elif align != "global": + aligns_headers[hidx] = align + minwidths = [ + max(minw, max(width_fn(cl) for cl in c)) + for minw, c in zip(minwidths, t_cols) + ] + headers = [ + _align_header(h, a, minw, width_fn(h), is_multiline, width_fn) + for h, a, minw in zip(headers, aligns_headers, minwidths) + ] + rows = list(zip(*cols)) + else: + minwidths = [max(width_fn(cl) for cl in c) for c in cols] + rows = list(zip(*cols)) + + if not isinstance(tablefmt, TableFormat): + tablefmt = _table_formats.get(tablefmt, _table_formats["simple"]) + + ra_default = rowalign if isinstance(rowalign, str) else None + rowaligns = _expand_iterable(rowalign, len(rows), ra_default) + return _format_table( + tablefmt, + headers, + aligns_headers, + rows, + minwidths, + aligns, + is_multiline, + rowaligns=rowaligns, + ) + + +def _expand_numparse(disable_numparse, column_count): + if isinstance(disable_numparse, Iterable): + numparses = [True] * column_count + for index in disable_numparse: + numparses[index] = False + return numparses + else: + return [not disable_numparse] * column_count + + +def _expand_iterable(original, num_desired, default): + if isinstance(original, Iterable) and not isinstance(original, str): + return original + [default] * (num_desired - len(original)) + else: + return [default] * num_desired + + +def _pad_row(cells, padding): + if cells: + if cells == SEPARATING_LINE: + return SEPARATING_LINE + pad = " " * padding + padded_cells = [pad + cell + pad for cell in cells] + return padded_cells + else: + return cells + + +def _build_simple_row(padded_cells, rowfmt): + begin, sep, end = rowfmt + return (begin + sep.join(padded_cells) + end).rstrip() + + +def _build_row(padded_cells, colwidths, colaligns, rowfmt): + if not rowfmt: + return None + if hasattr(rowfmt, "__call__"): + return rowfmt(padded_cells, colwidths, colaligns) + 
else: + return _build_simple_row(padded_cells, rowfmt) + +def _append_basic_row(lines, padded_cells, colwidths, colaligns, rowfmt, rowalign=None): + # NOTE: rowalign is ignored and exists for api compatibility with _append_multiline_row + lines.append(_build_row(padded_cells, colwidths, colaligns, rowfmt)) + return lines + +def _build_line(colwidths, colaligns, linefmt): + "Return a string which represents a horizontal line." + if not linefmt: + return None + if hasattr(linefmt, "__call__"): + return linefmt(colwidths, colaligns) + else: + begin, fill, sep, end = linefmt + cells = [fill * w for w in colwidths] + return _build_simple_row(cells, (begin, sep, end)) + + +def _append_line(lines, colwidths, colaligns, linefmt): + lines.append(_build_line(colwidths, colaligns, linefmt)) + return lines + +def _format_table( + fmt, headers, headersaligns, rows, colwidths, colaligns, is_multiline, rowaligns +): + lines = [] + hidden = fmt.with_header_hide if (headers and fmt.with_header_hide) else [] + pad = fmt.padding + headerrow = fmt.headerrow + + padded_widths = [(w + 2 * pad) for w in colwidths] + pad_row = _pad_row + append_row = _append_basic_row + + padded_headers = pad_row(headers, pad) + + if fmt.lineabove and "lineabove" not in hidden: + _append_line(lines, padded_widths, colaligns, fmt.lineabove) + + if padded_headers: + append_row(lines, padded_headers, padded_widths, headersaligns, headerrow) + if fmt.linebelowheader and "linebelowheader" not in hidden: + _append_line(lines, padded_widths, colaligns, fmt.linebelowheader) + + if rows and fmt.linebetweenrows and "linebetweenrows" not in hidden: + # initial rows with a line below + for row, ralign in zip(rows[:-1], rowaligns): + if row != SEPARATING_LINE: + append_row( + lines, + pad_row(row, pad), + padded_widths, + colaligns, + fmt.datarow, + rowalign=ralign, + ) + _append_line(lines, padded_widths, colaligns, fmt.linebetweenrows) + # the last row without a line below + append_row( + lines, + pad_row(rows[-1], pad), + padded_widths, + colaligns, + fmt.datarow, + rowalign=rowaligns[-1], + ) + else: + separating_line = ( + fmt.linebetweenrows + or fmt.linebelowheader + or fmt.linebelow + or fmt.lineabove + or Line("", "", "", "") + ) + for row in rows: + # test to see if either the 1st column or the 2nd column (account for showindex) has + # the SEPARATING_LINE flag + if _is_separating_line(row): + _append_line(lines, padded_widths, colaligns, separating_line) + else: + append_row( + lines, pad_row(row, pad), padded_widths, colaligns, fmt.datarow + ) + + if fmt.linebelow and "linebelow" not in hidden: + _append_line(lines, padded_widths, colaligns, fmt.linebelow) + + if headers or rows: + output = "\n".join(lines) + return output + else: # a completely empty table + return "" \ No newline at end of file diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 1366fcc0b..a00834cdd 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -222,6 +222,7 @@ class FunctionParent: class OriginalCodeBaseline(BaseModel): behavioral_test_results: TestResults benchmarking_test_results: TestResults + line_profile_results: dict runtime: int coverage_results: Optional[CoverageData] @@ -314,6 +315,7 @@ class FunctionCoverage: class TestingMode(enum.Enum): BEHAVIOR = "behavior" PERFORMANCE = "performance" + LINE_PROFILE = "line_profile" class VerificationType(str, Enum): diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 36d6c6f76..1fc86a5a6 100644 --- 
a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -7,7 +7,7 @@ import subprocess import time import uuid -from collections import defaultdict +from collections import defaultdict, deque from pathlib import Path from typing import TYPE_CHECKING @@ -37,6 +37,7 @@ ) from codeflash.code_utils.formatter import format_code, sort_imports from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test +from codeflash.code_utils.line_profile_utils import add_decorator_imports from codeflash.code_utils.remove_generated_tests import remove_functions_from_generated_tests from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast from codeflash.code_utils.time_utils import humanize_runtime @@ -66,8 +67,9 @@ from codeflash.verification.concolic_testing import generate_concolic_tests from codeflash.verification.equivalence import compare_test_results from codeflash.verification.instrument_codeflash_capture import instrument_codeflash_capture +from codeflash.verification.parse_line_profile_test_output import parse_line_profile_results from codeflash.verification.parse_test_output import parse_test_results -from codeflash.verification.test_runner import run_behavioral_tests, run_benchmarking_tests +from codeflash.verification.test_runner import run_behavioral_tests, run_benchmarking_tests, run_line_profile_tests from codeflash.verification.verification_utils import get_test_file_path from codeflash.verification.verifier import generate_tests @@ -230,7 +232,8 @@ def optimize_function(self) -> Result[BestOptimization, str]: ): cleanup_paths(paths_to_cleanup) return Failure("The threshold for test coverage was not met.") - + # request for new optimizations but don't block execution, check for completion later + # adding to control and experiment set but with same traceid best_optimization = None for _u, candidates in enumerate([optimizations_set.control, optimizations_set.experiment]): @@ -356,94 +359,123 @@ def determine_best_candidate( f"{self.function_to_optimize.qualified_name}…" ) console.rule() - try: - for candidate_index, candidate in enumerate(candidates, start=1): - get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True) - get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True) - logger.info(f"Optimization candidate {candidate_index}/{len(candidates)}:") - code_print(candidate.source_code) - try: - did_update = self.replace_function_and_helpers_with_optimized_code( - code_context=code_context, optimized_code=candidate.source_code - ) - if not did_update: - logger.warning( - "No functions were replaced in the optimized code. Skipping optimization candidate." 
+ candidates = deque(candidates) + # Start the AI service line-profiler request on a worker thread and run the evaluation loop on the main thread; + # when the request completes, its results are appended to the candidate queue. + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: + future_line_profile_results = executor.submit( + self.aiservice_client.optimize_python_code_line_profiler, + source_code=code_context.read_writable_code, + dependency_code=code_context.read_only_context_code, + trace_id=self.function_trace_id, + line_profiler_results=original_code_baseline.line_profile_results["str_out"], + num_candidates=10, + experiment_metadata=None, + ) + try: + candidate_index = 0 + original_len = len(candidates) + while candidates: + if future_line_profile_results is not None and future_line_profile_results.done(): + line_profile_results = future_line_profile_results.result() + candidates.extend(line_profile_results) + original_len += len(line_profile_results) + logger.info(f"Added line profiler results to the candidate queue; total candidates now: {original_len}") + future_line_profile_results = None + candidate_index += 1 + candidate = candidates.popleft() + get_run_tmp_file(Path(f"test_return_values_{candidate_index}.bin")).unlink(missing_ok=True) + get_run_tmp_file(Path(f"test_return_values_{candidate_index}.sqlite")).unlink(missing_ok=True) + logger.info(f"Optimization candidate {candidate_index}/{original_len}:") + code_print(candidate.source_code) + try: + did_update = self.replace_function_and_helpers_with_optimized_code( + code_context=code_context, optimized_code=candidate.source_code + ) + if not did_update: + logger.warning( + "No functions were replaced in the optimized code. Skipping optimization candidate."
+ ) + console.rule() + continue + except (ValueError, SyntaxError, cst.ParserSyntaxError, AttributeError) as e: + logger.error(e) + self.write_code_and_helpers( + self.function_to_optimize_source_code, + original_helper_code, + self.function_to_optimize.file_path, ) - console.rule() continue - except (ValueError, SyntaxError, cst.ParserSyntaxError, AttributeError) as e: - logger.error(e) - self.write_code_and_helpers( - self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path - ) - continue - - # Instrument codeflash capture - run_results = self.run_optimized_candidate( - optimization_candidate_index=candidate_index, - baseline_results=original_code_baseline, - original_helper_code=original_helper_code, - file_path_to_helper_classes=file_path_to_helper_classes, - ) - console.rule() - if not is_successful(run_results): - optimized_runtimes[candidate.optimization_id] = None - is_correct[candidate.optimization_id] = False - speedup_ratios[candidate.optimization_id] = None - else: - candidate_result: OptimizedCandidateResult = run_results.unwrap() - best_test_runtime = candidate_result.best_test_runtime - optimized_runtimes[candidate.optimization_id] = best_test_runtime - is_correct[candidate.optimization_id] = True - perf_gain = performance_gain( - original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=best_test_runtime + # Instrument codeflash capture + run_results = self.run_optimized_candidate( + optimization_candidate_index=candidate_index, + baseline_results=original_code_baseline, + original_helper_code=original_helper_code, + file_path_to_helper_classes=file_path_to_helper_classes, ) - speedup_ratios[candidate.optimization_id] = perf_gain - - tree = Tree(f"Candidate #{candidate_index} - Runtime Information") - if speedup_critic( - candidate_result, original_code_baseline.runtime, best_runtime_until_now - ) and quantity_of_tests_critic(candidate_result): - tree.add("This candidate is faster than the previous best candidate. 
🚀") - tree.add(f"Original summed runtime: {humanize_runtime(original_code_baseline.runtime)}") - tree.add( - f"Best summed runtime: {humanize_runtime(candidate_result.best_test_runtime)} " - f"(measured over {candidate_result.max_loop_count} " - f"loop{'s' if candidate_result.max_loop_count > 1 else ''})" - ) - tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%") - tree.add(f"Speedup ratio: {perf_gain + 1:.1f}X") - - best_optimization = BestOptimization( - candidate=candidate, - helper_functions=code_context.helper_functions, - runtime=best_test_runtime, - winning_behavioral_test_results=candidate_result.behavior_test_results, - winning_benchmarking_test_results=candidate_result.benchmarking_test_results, - ) - best_runtime_until_now = best_test_runtime + console.rule() + + if not is_successful(run_results): + optimized_runtimes[candidate.optimization_id] = None + is_correct[candidate.optimization_id] = False + speedup_ratios[candidate.optimization_id] = None else: - tree.add( - f"Summed runtime: {humanize_runtime(best_test_runtime)} " - f"(measured over {candidate_result.max_loop_count} " - f"loop{'s' if candidate_result.max_loop_count > 1 else ''})" + candidate_result: OptimizedCandidateResult = run_results.unwrap() + best_test_runtime = candidate_result.best_test_runtime + optimized_runtimes[candidate.optimization_id] = best_test_runtime + is_correct[candidate.optimization_id] = True + perf_gain = performance_gain( + original_runtime_ns=original_code_baseline.runtime, optimized_runtime_ns=best_test_runtime ) - tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%") - tree.add(f"Speedup ratio: {perf_gain + 1:.3f}X") - console.print(tree) - console.rule() + speedup_ratios[candidate.optimization_id] = perf_gain + + tree = Tree(f"Candidate #{candidate_index} - Runtime Information") + if speedup_critic( + candidate_result, original_code_baseline.runtime, best_runtime_until_now + ) and quantity_of_tests_critic(candidate_result): + tree.add("This candidate is faster than the previous best candidate. 
🚀") + tree.add(f"Original summed runtime: {humanize_runtime(original_code_baseline.runtime)}") + tree.add( + f"Best summed runtime: {humanize_runtime(candidate_result.best_test_runtime)} " + f"(measured over {candidate_result.max_loop_count} " + f"loop{'s' if candidate_result.max_loop_count > 1 else ''})" + ) + tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%") + tree.add(f"Speedup ratio: {perf_gain + 1:.1f}X") + + best_optimization = BestOptimization( + candidate=candidate, + helper_functions=code_context.helper_functions, + runtime=best_test_runtime, + winning_behavioral_test_results=candidate_result.behavior_test_results, + winning_benchmarking_test_results=candidate_result.benchmarking_test_results, + ) + best_runtime_until_now = best_test_runtime + else: + tree.add( + f"Summed runtime: {humanize_runtime(best_test_runtime)} " + f"(measured over {candidate_result.max_loop_count} " + f"loop{'s' if candidate_result.max_loop_count > 1 else ''})" + ) + tree.add(f"Speedup percentage: {perf_gain * 100:.1f}%") + tree.add(f"Speedup ratio: {perf_gain + 1:.3f}X") + console.print(tree) + console.rule() + + self.write_code_and_helpers( + self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path + ) + except KeyboardInterrupt as e: self.write_code_and_helpers( self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path ) - except KeyboardInterrupt as e: - self.write_code_and_helpers( - self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path - ) - logger.exception(f"Optimization interrupted: {e}") - raise + logger.exception(f"Optimization interrupted: {e}") + raise self.aiservice_client.log_results( function_trace_id=self.function_trace_id, @@ -758,6 +790,7 @@ def establish_original_code_baseline( original_helper_code: dict[Path, str], file_path_to_helper_classes: dict[Path, set[str]], ) -> Result[tuple[OriginalCodeBaseline, list[str]], str]: + line_profile_results = {"timings": {}, "unit": 0, "str_out": ""} # For the original function - run the tests and get the runtime, plus coverage with progress_bar(f"Establishing original code baseline for {self.function_to_optimize.function_name}"): assert (test_framework := self.args.test_framework) in ["pytest", "unittest"] @@ -801,6 +834,28 @@ def establish_original_code_baseline( if not coverage_critic(coverage_results, self.args.test_framework): return Failure("The threshold for test coverage was not met.") if test_framework == "pytest": + try: + line_profiler_output_file = add_decorator_imports(self.function_to_optimize, code_context) + line_profile_results, _ = self.run_and_parse_tests( + testing_type=TestingMode.LINE_PROFILE, + test_env=test_env, + test_files=self.test_files, + optimization_iteration=0, + testing_time=TOTAL_LOOPING_TIME, + enable_coverage=False, + code_context=code_context, + line_profiler_output_file=line_profiler_output_file, + ) + finally: + # Remove codeflash capture + self.write_code_and_helpers( + self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path + ) + if line_profile_results["str_out"] == "": + logger.warning( + f"Couldn't run line profiler for original function {self.function_to_optimize.function_name}" + ) + console.rule() benchmarking_results, _ = self.run_and_parse_tests( testing_type=TestingMode.PERFORMANCE, test_env=test_env, @@ -872,6 +927,7 @@ def establish_original_code_baseline( benchmarking_test_results=benchmarking_results, runtime=total_timing, 
coverage_results=coverage_results, + line_profile_results=line_profile_results, ), functions_to_remove, ) @@ -1007,7 +1063,8 @@ def run_and_parse_tests( pytest_max_loops: int = 100_000, code_context: CodeOptimizationContext | None = None, unittest_loop_index: int | None = None, - ) -> tuple[TestResults, CoverageData | None]: + line_profiler_output_file: Path | None = None, + ) -> tuple[TestResults | dict, CoverageData | None]: coverage_database_file = None coverage_config_file = None try: @@ -1021,6 +1078,19 @@ def run_and_parse_tests( verbose=True, enable_coverage=enable_coverage, ) + elif testing_type == TestingMode.LINE_PROFILE: + result_file_path, run_result = run_line_profile_tests( + test_files, + cwd=self.project_root, + test_env=test_env, + pytest_cmd=self.test_cfg.pytest_cmd, + pytest_timeout=INDIVIDUAL_TESTCASE_TIMEOUT, + pytest_target_runtime_seconds=testing_time, + pytest_min_loops=1, + pytest_max_loops=1, + test_framework=self.test_cfg.test_framework, + line_profiler_output_file=line_profiler_output_file, + ) elif testing_type == TestingMode.PERFORMANCE: result_file_path, run_result = run_benchmarking_tests( test_files, @@ -1048,20 +1118,22 @@ def run_and_parse_tests( f"stdout: {run_result.stdout}\n" f"stderr: {run_result.stderr}\n" ) - - results, coverage_results = parse_test_results( - test_xml_path=result_file_path, - test_files=test_files, - test_config=self.test_cfg, - optimization_iteration=optimization_iteration, - run_result=run_result, - unittest_loop_index=unittest_loop_index, - function_name=self.function_to_optimize.function_name, - source_file=self.function_to_optimize.file_path, - code_context=code_context, - coverage_database_file=coverage_database_file, - coverage_config_file=coverage_config_file, - ) + if testing_type in [TestingMode.BEHAVIOR, TestingMode.PERFORMANCE]: + results, coverage_results = parse_test_results( + test_xml_path=result_file_path, + test_files=test_files, + test_config=self.test_cfg, + optimization_iteration=optimization_iteration, + run_result=run_result, + unittest_loop_index=unittest_loop_index, + function_name=self.function_to_optimize.function_name, + source_file=self.function_to_optimize.file_path, + code_context=code_context, + coverage_database_file=coverage_database_file, + coverage_config_file=coverage_config_file, + ) + else: + results, coverage_results = parse_line_profile_results(line_profiler_output_file=line_profiler_output_file) return results, coverage_results def generate_and_instrument_tests( diff --git a/codeflash/verification/parse_line_profile_test_output.py b/codeflash/verification/parse_line_profile_test_output.py new file mode 100644 index 000000000..0536d4825 --- /dev/null +++ b/codeflash/verification/parse_line_profile_test_output.py @@ -0,0 +1,88 @@ +"""Adapted from line_profiler (https://github.com/pyutils/line_profiler) written by Enthought, Inc. (BSD License)""" +import linecache +import inspect +from codeflash.code_utils.tabulate import tabulate +import os +import dill as pickle +from pathlib import Path +from typing import Optional + +def show_func(filename, start_lineno, func_name, timings, unit): + total_hits = sum(t[1] for t in timings) + total_time = sum(t[2] for t in timings) + out_table = "" + table_rows = [] + if total_hits == 0: + return '' + scalar = 1 + if os.path.exists(filename): + out_table+=f'## Function: {func_name}\n' + # Clear the cache to ensure that we get up-to-date results. 
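+    # (instrumentation rewrote the profiled source files on disk, so previously cached lines may be stale)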
+        linecache.clearcache()
+        all_lines = linecache.getlines(filename)
+        sublines = inspect.getblock(all_lines[start_lineno - 1:])
+        out_table += '## Total time: %g s\n' % (total_time * unit)
+        # Define minimum column sizes so text fits and usually looks consistent
+        default_column_sizes = {
+            'hits': 9,
+            'time': 12,
+            'perhit': 8,
+            'percent': 8,
+        }
+        display = {}
+        # Loop over each line to determine better column formatting.
+        # Fall back to scientific notation if columns are larger than a threshold.
+        for lineno, nhits, time in timings:
+            if total_time == 0:  # Happens rarely on empty function
+                percent = ''
+            else:
+                percent = '%5.1f' % (100 * time / total_time)
+
+            time_disp = '%5.1f' % (time * scalar)
+            if len(time_disp) > default_column_sizes['time']:
+                time_disp = '%5.1g' % (time * scalar)
+            perhit_disp = '%5.1f' % (float(time) * scalar / nhits)
+            if len(perhit_disp) > default_column_sizes['perhit']:
+                perhit_disp = '%5.1g' % (float(time) * scalar / nhits)
+            nhits_disp = "%d" % nhits
+            if len(nhits_disp) > default_column_sizes['hits']:
+                nhits_disp = '%g' % nhits
+            display[lineno] = (nhits_disp, time_disp, perhit_disp, percent)
+        linenos = range(start_lineno, start_lineno + len(sublines))
+        empty = ('', '', '', '')
+        table_cols = ('Hits', 'Time', 'Per Hit', '% Time', 'Line Contents')
+        for lineno, line in zip(linenos, sublines):
+            nhits, time, per_hit, percent = display.get(lineno, empty)
+            line_ = line.rstrip('\n').rstrip('\r')
+            if 'def' in line_ or nhits != '':
+                table_rows.append((nhits, time, per_hit, percent, line_))
+        out_table += tabulate(headers=table_cols, tabular_data=table_rows, tablefmt="pipe", colglobalalign=None, preserve_whitespace=True)
+        out_table += '\n'
+    return out_table
+
+
+def show_text(stats: dict) -> str:
+    """Show text for the given timings."""
+    out_table = ""
+    out_table += '# Timer unit: %g s\n' % stats['unit']
+    stats_order = sorted(stats['timings'].items())
+    # Show detailed per-line information for each function.
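+    # `stats['timings']` is keyed by (filename, first_lineno, func_name), so
+    # sorting the items orders the report by file and then by each function's
+    # position within that file.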
+    for (fn, lineno, name), timings in stats_order:
+        table_md = show_func(fn, lineno, name, timings, stats['unit'])
+        out_table += table_md
+    return out_table
+
+
+def parse_line_profile_results(line_profiler_output_file: Optional[Path]) -> tuple[dict, None]:
+    line_profiler_output_file = line_profiler_output_file.with_suffix(".lprof")
+    if not line_profiler_output_file.exists():
+        return {'timings': {}, 'unit': 0, 'str_out': ''}, None
+    with open(line_profiler_output_file, 'rb') as f:
+        stats = pickle.load(f)
+    stats_dict = {'timings': stats.timings, 'unit': stats.unit}
+    stats_dict['str_out'] = show_text(stats_dict)
+    return stats_dict, None
\ No newline at end of file
diff --git a/codeflash/verification/test_runner.py b/codeflash/verification/test_runner.py
index d4b3f15b4..fb53d8652 100644
--- a/codeflash/verification/test_runner.py
+++ b/codeflash/verification/test_runner.py
@@ -141,6 +141,67 @@ def run_behavioral_tests(
         coverage_config_file if enable_coverage else None,
     )
 
+
+def run_line_profile_tests(
+    test_paths: TestFiles,
+    pytest_cmd: str,
+    test_env: dict[str, str],
+    cwd: Path,
+    test_framework: str,
+    *,
+    pytest_target_runtime_seconds: float = TOTAL_LOOPING_TIME,
+    verbose: bool = False,
+    pytest_timeout: int | None = None,
+    pytest_min_loops: int = 5,
+    pytest_max_loops: int = 100_000,
+    line_profiler_output_file: Path | None = None,
+) -> tuple[Path, subprocess.CompletedProcess]:
+    if test_framework == "pytest":
+        pytest_cmd_list = (
+            shlex.split(f"{SAFE_SYS_EXECUTABLE} -m pytest", posix=IS_POSIX)
+            if pytest_cmd == "pytest"
+            else shlex.split(pytest_cmd)
+        )
+        test_files: list[str] = []
+        for file in test_paths.test_files:
+            if file.test_type in [TestType.REPLAY_TEST, TestType.EXISTING_UNIT_TEST] and file.tests_in_file:
+                test_files.extend(
+                    [
+                        str(file.benchmarking_file_path)
+                        + "::"
+                        + (test.test_class + "::" if test.test_class else "")
+                        + (test.test_function.split("[", 1)[0] if "[" in test.test_function else test.test_function)
+                        for test in file.tests_in_file
+                    ]
+                )
+            else:
+                test_files.append(str(file.benchmarking_file_path))
+        test_files = list(set(test_files))  # remove multiple calls in the same test function
+        pytest_args = [
+            "--capture=tee-sys",
+            f"--timeout={pytest_timeout}",
+            "-q",
+            "--codeflash_loops_scope=session",
+            f"--codeflash_min_loops={pytest_min_loops}",
+            f"--codeflash_max_loops={pytest_max_loops}",
+            f"--codeflash_seconds={pytest_target_runtime_seconds}",
+        ]
+        result_file_path = get_run_tmp_file(Path("pytest_results.xml"))
+        result_args = [f"--junitxml={result_file_path.as_posix()}", "-o", "junit_logging=all"]
+        pytest_test_env = test_env.copy()
+        pytest_test_env["PYTEST_PLUGINS"] = "codeflash.verification.pytest_plugin"
+        pytest_test_env["LINE_PROFILE"] = "1"
+        blocklist_args = [f"-p no:{plugin}" for plugin in BENCHMARKING_BLOCKLISTED_PLUGINS]
+        results = execute_test_subprocess(
+            pytest_cmd_list + pytest_args + blocklist_args + result_args + test_files,
+            cwd=cwd,
+            env=pytest_test_env,
+            timeout=600,  # TODO: Make this dynamic
+        )
+    else:
+        msg = f"Unsupported test framework: {test_framework}"
+        raise ValueError(msg)
+    return line_profiler_output_file, results
+
 
 def run_benchmarking_tests(
     test_paths: TestFiles,
diff --git a/pyproject.toml b/pyproject.toml
index e8aa01d75..a181fac2e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -92,6 +92,7 @@ rich = ">=13.8.1"
 lxml = ">=5.3.0"
 crosshair-tool = ">=0.0.78"
 coverage = ">=7.6.4"
+line_profiler = ">=4.2.0"  # this is the minimum
version which supports python 3.13 [tool.poetry.group.dev] optional = true diff --git a/tests/test_instrument_line_profiler.py b/tests/test_instrument_line_profiler.py new file mode 100644 index 000000000..1161bd7cd --- /dev/null +++ b/tests/test_instrument_line_profiler.py @@ -0,0 +1,291 @@ +import os +from pathlib import Path +from tempfile import TemporaryDirectory + +from codeflash.code_utils.line_profile_utils import add_decorator_imports +from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash.models.models import CodeOptimizationContext +from codeflash.optimization.function_optimizer import FunctionOptimizer +from codeflash.verification.verification_utils import TestConfig + + +def test_add_decorator_imports_helper_in_class(): + code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort_classmethod.py").resolve() + tests_root = Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/" + project_root_path = (Path(__file__).parent / "..").resolve() + run_cwd = Path(__file__).parent.parent.resolve() + test_config = TestConfig( + tests_root=tests_root, + tests_project_rootdir=project_root_path, + project_root_path=project_root_path, + test_framework="pytest", + pytest_cmd="pytest", + ) + func = FunctionToOptimize(function_name="sort_classmethod", parents=[], file_path=code_path) + func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config) + os.chdir(run_cwd) + #func_optimizer = pass + try: + ctx_result = func_optimizer.get_code_optimization_context() + code_context: CodeOptimizationContext = ctx_result.unwrap() + original_helper_code: dict[Path, str] = {} + helper_function_paths = {hf.file_path for hf in code_context.helper_functions} + for helper_function_path in helper_function_paths: + with helper_function_path.open(encoding="utf8") as f: + helper_code = f.read() + original_helper_code[helper_function_path] = helper_code + line_profiler_output_file = add_decorator_imports( + func_optimizer.function_to_optimize, code_context) + expected_code_main = f"""from line_profiler import profile as codeflash_line_profile +codeflash_line_profile.enable(output_prefix='{line_profiler_output_file}') + +from code_to_optimize.bubble_sort_in_class import BubbleSortClass + + +@codeflash_line_profile +def sort_classmethod(x): + y = BubbleSortClass() + return y.sorter(x) +""" + expected_code_helper = """from line_profiler import profile as codeflash_line_profile + + +def hi(): + pass + + +class BubbleSortClass: + def __init__(self): + pass + + @codeflash_line_profile + def sorter(self, arr): + n = len(arr) + for i in range(n): + for j in range(0, n - i - 1): + if arr[j] > arr[j + 1]: + arr[j], arr[j + 1] = arr[j + 1], arr[j] + return arr + + def helper(self, arr, j): + return arr[j] > arr[j + 1] +""" + assert code_path.read_text("utf-8") == expected_code_main + assert code_context.helper_functions[0].file_path.read_text("utf-8") == expected_code_helper + finally: + func_optimizer.write_code_and_helpers( + func_optimizer.function_to_optimize_source_code, original_helper_code, func_optimizer.function_to_optimize.file_path + ) + +def test_add_decorator_imports_helper_in_nested_class(): + #Need to invert the assert once the helper detection is fixed + code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort_nested_classmethod.py").resolve() + tests_root = Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/" + project_root_path = (Path(__file__).parent / "..").resolve() + run_cwd 
= Path(__file__).parent.parent.resolve() + test_config = TestConfig( + tests_root=tests_root, + tests_project_rootdir=project_root_path, + project_root_path=project_root_path, + test_framework="pytest", + pytest_cmd="pytest", + ) + func = FunctionToOptimize(function_name="sort_classmethod", parents=[], file_path=code_path) + func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config) + os.chdir(run_cwd) + #func_optimizer = pass + try: + ctx_result = func_optimizer.get_code_optimization_context() + code_context: CodeOptimizationContext = ctx_result.unwrap() + original_helper_code: dict[Path, str] = {} + helper_function_paths = {hf.file_path for hf in code_context.helper_functions} + for helper_function_path in helper_function_paths: + with helper_function_path.open(encoding="utf8") as f: + helper_code = f.read() + original_helper_code[helper_function_path] = helper_code + line_profiler_output_file = add_decorator_imports( + func_optimizer.function_to_optimize, code_context) + expected_code_main = f"""from line_profiler import profile as codeflash_line_profile +codeflash_line_profile.enable(output_prefix='{line_profiler_output_file}') + +from code_to_optimize.bubble_sort_in_nested_class import WrapperClass + + +@codeflash_line_profile +def sort_classmethod(x): + y = WrapperClass.BubbleSortClass() + return y.sorter(x) +""" + assert code_path.read_text("utf-8") == expected_code_main + assert code_context.helper_functions.__len__() == 0 + finally: + func_optimizer.write_code_and_helpers( + func_optimizer.function_to_optimize_source_code, original_helper_code, func_optimizer.function_to_optimize.file_path + ) + +def test_add_decorator_imports_nodeps(): + code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort.py").resolve() + tests_root = Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/" + project_root_path = (Path(__file__).parent / "..").resolve() + run_cwd = Path(__file__).parent.parent.resolve() + test_config = TestConfig( + tests_root=tests_root, + tests_project_rootdir=project_root_path, + project_root_path=project_root_path, + test_framework="pytest", + pytest_cmd="pytest", + ) + func = FunctionToOptimize(function_name="sorter", parents=[], file_path=code_path) + func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config) + os.chdir(run_cwd) + #func_optimizer = pass + try: + ctx_result = func_optimizer.get_code_optimization_context() + code_context: CodeOptimizationContext = ctx_result.unwrap() + original_helper_code: dict[Path, str] = {} + helper_function_paths = {hf.file_path for hf in code_context.helper_functions} + for helper_function_path in helper_function_paths: + with helper_function_path.open(encoding="utf8") as f: + helper_code = f.read() + original_helper_code[helper_function_path] = helper_code + line_profiler_output_file = add_decorator_imports( + func_optimizer.function_to_optimize, code_context) + expected_code_main = f"""from line_profiler import profile as codeflash_line_profile +codeflash_line_profile.enable(output_prefix='{line_profiler_output_file}') + + +@codeflash_line_profile +def sorter(arr): + print("codeflash stdout: Sorting list") + for i in range(len(arr)): + for j in range(len(arr) - 1): + if arr[j] > arr[j + 1]: + temp = arr[j] + arr[j] = arr[j + 1] + arr[j + 1] = temp + print(f"result: {{arr}}") + return arr +""" + assert code_path.read_text("utf-8") == expected_code_main + finally: + func_optimizer.write_code_and_helpers( + 
func_optimizer.function_to_optimize_source_code, original_helper_code, func_optimizer.function_to_optimize.file_path + ) + +def test_add_decorator_imports_helper_outside(): + code_path = (Path(__file__).parent.resolve() / "../code_to_optimize/bubble_sort_deps.py").resolve() + tests_root = Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/" + project_root_path = (Path(__file__).parent / "..").resolve() + run_cwd = Path(__file__).parent.parent.resolve() + test_config = TestConfig( + tests_root=tests_root, + tests_project_rootdir=project_root_path, + project_root_path=project_root_path, + test_framework="pytest", + pytest_cmd="pytest", + ) + func = FunctionToOptimize(function_name="sorter_deps", parents=[], file_path=code_path) + func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config) + os.chdir(run_cwd) + #func_optimizer = pass + try: + ctx_result = func_optimizer.get_code_optimization_context() + code_context: CodeOptimizationContext = ctx_result.unwrap() + original_helper_code: dict[Path, str] = {} + helper_function_paths = {hf.file_path for hf in code_context.helper_functions} + for helper_function_path in helper_function_paths: + with helper_function_path.open(encoding="utf8") as f: + helper_code = f.read() + original_helper_code[helper_function_path] = helper_code + line_profiler_output_file = add_decorator_imports( + func_optimizer.function_to_optimize, code_context) + expected_code_main = f"""from line_profiler import profile as codeflash_line_profile +codeflash_line_profile.enable(output_prefix='{line_profiler_output_file}') + +from code_to_optimize.bubble_sort_dep1_helper import dep1_comparer +from code_to_optimize.bubble_sort_dep2_swap import dep2_swap + + +@codeflash_line_profile +def sorter_deps(arr): + for i in range(len(arr)): + for j in range(len(arr) - 1): + if dep1_comparer(arr, j): + dep2_swap(arr, j) + return arr + +""" + expected_code_helper1 = """from line_profiler import profile as codeflash_line_profile + + +@codeflash_line_profile +def dep1_comparer(arr, j: int) -> bool: + return arr[j] > arr[j + 1] +""" + expected_code_helper2="""from line_profiler import profile as codeflash_line_profile + + +@codeflash_line_profile +def dep2_swap(arr, j): + temp = arr[j] + arr[j] = arr[j + 1] + arr[j + 1] = temp +""" + assert code_path.read_text("utf-8") == expected_code_main + assert code_context.helper_functions[0].file_path.read_text("utf-8") == expected_code_helper1 + assert code_context.helper_functions[1].file_path.read_text("utf-8") == expected_code_helper2 + finally: + func_optimizer.write_code_and_helpers( + func_optimizer.function_to_optimize_source_code, original_helper_code, func_optimizer.function_to_optimize.file_path + ) + +def test_add_decorator_imports_helper_in_dunder_class(): + code_str = """def sorter(arr): + ans = helper(arr) + return ans +class helper: + def __init__(self, arr): + return arr.sort()""" + code_path = TemporaryDirectory() + code_write_path = Path(code_path.name) / "dunder_class.py" + code_write_path.write_text(code_str,"utf-8") + tests_root = Path(__file__).parent.resolve() / "../code_to_optimize/tests/pytest/" + project_root_path = Path(code_path.name) + run_cwd = Path(__file__).parent.parent.resolve() + test_config = TestConfig( + tests_root=tests_root, + tests_project_rootdir=project_root_path, + project_root_path=project_root_path, + test_framework="pytest", + pytest_cmd="pytest", + ) + func = FunctionToOptimize(function_name="sorter", parents=[], file_path=code_write_path) + func_optimizer = 
FunctionOptimizer(function_to_optimize=func, test_cfg=test_config) + os.chdir(run_cwd) + #func_optimizer = pass + try: + ctx_result = func_optimizer.get_code_optimization_context() + code_context: CodeOptimizationContext = ctx_result.unwrap() + original_helper_code: dict[Path, str] = {} + helper_function_paths = {hf.file_path for hf in code_context.helper_functions} + for helper_function_path in helper_function_paths: + with helper_function_path.open(encoding="utf8") as f: + helper_code = f.read() + original_helper_code[helper_function_path] = helper_code + line_profiler_output_file = add_decorator_imports( + func_optimizer.function_to_optimize, code_context) + expected_code_main = f"""from line_profiler import profile as codeflash_line_profile +codeflash_line_profile.enable(output_prefix='{line_profiler_output_file}') + + +@codeflash_line_profile +def sorter(arr): + ans = helper(arr) + return ans +class helper: + def __init__(self, arr): + return arr.sort() +""" + assert code_write_path.read_text("utf-8") == expected_code_main + finally: + pass diff --git a/tests/test_instrument_tests.py b/tests/test_instrument_tests.py index 16be1966e..44661912a 100644 --- a/tests/test_instrument_tests.py +++ b/tests/test_instrument_tests.py @@ -12,8 +12,10 @@ FunctionImportedAsVisitor, inject_profiling_into_existing_test, ) +from codeflash.code_utils.line_profile_utils import add_decorator_imports from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.models.models import ( + CodeOptimizationContext, CodePosition, FunctionParent, TestFile, @@ -283,6 +285,7 @@ def test_prepare_image_for_yolo(): def test_perfinjector_bubble_sort_results() -> None: + computed_fn_opt = False code = """from code_to_optimize.bubble_sort import sorter @@ -495,13 +498,40 @@ def test_sort(): codeflash stdout: Sorting list result: [0.0, 1.0, 2.0, 3.0, 4.0, 5.0]""" assert out_str == test_results_perf[1].stdout - + ctx_result = func_optimizer.get_code_optimization_context() + code_context: CodeOptimizationContext = ctx_result.unwrap() + original_helper_code: dict[Path, str] = {} + helper_function_paths = {hf.file_path for hf in code_context.helper_functions} + for helper_function_path in helper_function_paths: + with helper_function_path.open(encoding="utf8") as f: + helper_code = f.read() + original_helper_code[helper_function_path] = helper_code + computed_fn_opt = True + line_profiler_output_file = add_decorator_imports( + func_optimizer.function_to_optimize, code_context) + line_profile_results, _ = func_optimizer.run_and_parse_tests( + testing_type=TestingMode.LINE_PROFILE, + test_env=test_env, + test_files=test_files, + optimization_iteration=0, + pytest_min_loops=1, + pytest_max_loops=1, + testing_time=0.1, + line_profiler_output_file = line_profiler_output_file + ) + tmp_lpr = list(line_profile_results["timings"].keys()) + assert len(tmp_lpr) == 1 and line_profile_results["timings"][tmp_lpr[0]][0][1]==2 finally: + if computed_fn_opt: + func_optimizer.write_code_and_helpers( + func_optimizer.function_to_optimize_source_code, original_helper_code, func_optimizer.function_to_optimize.file_path + ) test_path.unlink(missing_ok=True) test_path_perf.unlink(missing_ok=True) def test_perfinjector_bubble_sort_parametrized_results() -> None: + computed_fn_opt = False code = """from code_to_optimize.bubble_sort import sorter import pytest @@ -721,13 +751,40 @@ def test_sort_parametrized(input, expected_output): ) assert test_results_perf[2].runtime > 0 assert test_results_perf[2].did_pass - + ctx_result 
= func_optimizer.get_code_optimization_context() + code_context: CodeOptimizationContext = ctx_result.unwrap() + original_helper_code: dict[Path, str] = {} + helper_function_paths = {hf.file_path for hf in code_context.helper_functions} + for helper_function_path in helper_function_paths: + with helper_function_path.open(encoding="utf8") as f: + helper_code = f.read() + original_helper_code[helper_function_path] = helper_code + computed_fn_opt = True + line_profiler_output_file = add_decorator_imports( + func_optimizer.function_to_optimize, code_context) + line_profile_results, _ = func_optimizer.run_and_parse_tests( + testing_type=TestingMode.LINE_PROFILE, + test_env=test_env, + test_files=test_files, + optimization_iteration=0, + pytest_min_loops=1, + pytest_max_loops=1, + testing_time=0.1, + line_profiler_output_file = line_profiler_output_file + ) + tmp_lpr = list(line_profile_results["timings"].keys()) + assert len(tmp_lpr) == 1 and line_profile_results["timings"][tmp_lpr[0]][0][1]==3 finally: + if computed_fn_opt: + func_optimizer.write_code_and_helpers( + func_optimizer.function_to_optimize_source_code, original_helper_code, func_optimizer.function_to_optimize.file_path + ) test_path.unlink(missing_ok=True) test_path_perf.unlink(missing_ok=True) def test_perfinjector_bubble_sort_parametrized_loop_results() -> None: + computed_fn_opt = False code = """from code_to_optimize.bubble_sort import sorter import pytest @@ -1033,13 +1090,41 @@ def test_sort_parametrized_loop(input, expected_output): ) assert test_results[5].runtime > 0 assert test_results[5].did_pass + ctx_result = func_optimizer.get_code_optimization_context() + code_context: CodeOptimizationContext = ctx_result.unwrap() + original_helper_code: dict[Path, str] = {} + helper_function_paths = {hf.file_path for hf in code_context.helper_functions} + for helper_function_path in helper_function_paths: + with helper_function_path.open(encoding="utf8") as f: + helper_code = f.read() + original_helper_code[helper_function_path] = helper_code + computed_fn_opt = True + line_profiler_output_file = add_decorator_imports( + func_optimizer.function_to_optimize, code_context) + line_profile_results, _ = func_optimizer.run_and_parse_tests( + testing_type=TestingMode.LINE_PROFILE, + test_env=test_env, + test_files=test_files, + optimization_iteration=0, + pytest_min_loops=1, + pytest_max_loops=1, + testing_time=0.1, + line_profiler_output_file = line_profiler_output_file + ) + tmp_lpr = list(line_profile_results["timings"].keys()) + assert len(tmp_lpr) == 1 and line_profile_results["timings"][tmp_lpr[0]][0][1]==6 finally: + if computed_fn_opt: + func_optimizer.write_code_and_helpers( + func_optimizer.function_to_optimize_source_code, original_helper_code, func_optimizer.function_to_optimize.file_path + ) test_path.unlink(missing_ok=True) test_path_behavior.unlink(missing_ok=True) test_path_perf.unlink(missing_ok=True) def test_perfinjector_bubble_sort_loop_results() -> None: + computed_fn_opt = False code = """from code_to_optimize.bubble_sort import sorter @@ -1278,7 +1363,34 @@ def test_sort(): ) assert test_results[2].runtime > 0 assert test_results[2].did_pass + ctx_result = func_optimizer.get_code_optimization_context() + code_context: CodeOptimizationContext = ctx_result.unwrap() + original_helper_code: dict[Path, str] = {} + helper_function_paths = {hf.file_path for hf in code_context.helper_functions} + for helper_function_path in helper_function_paths: + with helper_function_path.open(encoding="utf8") as f: + helper_code = 
f.read() + original_helper_code[helper_function_path] = helper_code + computed_fn_opt = True + line_profiler_output_file = add_decorator_imports( + func_optimizer.function_to_optimize, code_context) + line_profile_results, _ = func_optimizer.run_and_parse_tests( + testing_type=TestingMode.LINE_PROFILE, + test_env=test_env, + test_files=test_files, + optimization_iteration=0, + pytest_min_loops=1, + pytest_max_loops=1, + testing_time=0.1, + line_profiler_output_file = line_profiler_output_file + ) + tmp_lpr = list(line_profile_results["timings"].keys()) + assert len(tmp_lpr) == 1 and line_profile_results["timings"][tmp_lpr[0]][0][1]==3 finally: + if computed_fn_opt is True: + func_optimizer.write_code_and_helpers( + func_optimizer.function_to_optimize_source_code, original_helper_code, func_optimizer.function_to_optimize.file_path + ) test_path.unlink(missing_ok=True) test_path_perf.unlink(missing_ok=True) test_path_behavior.unlink(missing_ok=True)
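For reviewers: below is a minimal standalone sketch, not part of the diff, of the public line_profiler API this change relies on. `sorter` here is a stand-in function; the LineStats fields read at the end are the same ones parse_line_profile_results unpickles from the .lprof file.

# Sketch only; assumes line_profiler >= 4.2.0 as pinned in pyproject.toml.
from line_profiler import LineProfiler


def sorter(arr):
    for i in range(len(arr)):
        for j in range(len(arr) - 1):
            if arr[j] > arr[j + 1]:
                arr[j], arr[j + 1] = arr[j + 1], arr[j]
    return arr


lp = LineProfiler(sorter)  # register the function to profile
lp.runcall(sorter, [5, 3, 1, 4, 2])  # run it once under the profiler
stats = lp.get_stats()  # LineStats, the object pickled into the .lprof file
# stats.timings: {(filename, first_lineno, func_name): [(lineno, nhits, time), ...]}
# stats.unit: duration of one timer tick, in seconds
for (filename, first_lineno, func_name), rows in stats.timings.items():
    print(func_name, rows)
print("timer unit (s):", stats.unit)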