diff --git a/code_to_optimize/code_directories/circular_deps/constants.py b/code_to_optimize/code_directories/circular_deps/constants.py index dc4b0638e..be8fdac15 100644 --- a/code_to_optimize/code_directories/circular_deps/constants.py +++ b/code_to_optimize/code_directories/circular_deps/constants.py @@ -1,8 +1,2 @@ DEFAULT_API_URL = "https://api.galileo.ai/" DEFAULT_APP_URL = "https://app.galileo.ai/" - - -# function_names: GalileoApiClient.get_console_url -# module_abs_path : /home/mohammed/Work/galileo-python/src/galileo/api_client.py -# preexisting_objects: {('GalileoApiClient', ()), ('_set_destination', ()), ('get_console_url', (FunctionParent(name='GalileoApiClient', type='ClassDef'),))} -# project_root_path: /home/mohammed/Work/galileo-python/src diff --git a/code_to_optimize/code_directories/unstructured_example/base.py b/code_to_optimize/code_directories/unstructured_example/base.py new file mode 100644 index 000000000..34b4e9891 --- /dev/null +++ b/code_to_optimize/code_directories/unstructured_example/base.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +import base64 +import json +import zlib +from copy import deepcopy +from typing import Any, Iterable +from utils import Point + +from coordinates import PixelSpace +from elements import ( + TYPE_TO_TEXT_ELEMENT_MAP, + CheckBox, + Element, + ElementMetadata, +) + +# ================================================================================================ +# SERIALIZATION/DESERIALIZATION (SERDE) RELATED FUNCTIONS +# ================================================================================================ +# These serde functions will likely relocate to `unstructured.documents.elements` since they are +# so closely related to elements and this staging "brick" is deprecated. 
+# ================================================================================================ + +# == DESERIALIZERS =============================== + + +def elements_from_base64_gzipped_json(b64_encoded_elements: str) -> list[Element]: + """Restore Base64-encoded gzipped JSON elements to element objects. + + This is used to when deserializing `ElementMetadata.orig_elements` from its compressed form in + JSON and dict forms and perhaps for other purposes. + """ + # -- Base64 str -> gzip-encoded (JSON) bytes -- + decoded_b64_bytes = base64.b64decode(b64_encoded_elements) + # -- undo gzip compression -- + elements_json_bytes = zlib.decompress(decoded_b64_bytes) + # -- JSON (bytes) to JSON (str) -- + elements_json_str = elements_json_bytes.decode("utf-8") + # -- JSON (str) -> dicts -- + element_dicts = json.loads(elements_json_str) + # -- dicts -> elements -- + return elements_from_dicts(element_dicts) + + +def elements_from_dicts(element_dicts: Iterable[dict[str, Any]]) -> list[Element]: + """Convert a list of element-dicts to a list of elements.""" + elements: list[Element] = [] + + for item in element_dicts: + element_id: str = item.get("element_id", None) + metadata = ( + ElementMetadata() + if item.get("metadata") is None + else ElementMetadata.from_dict(item["metadata"]) + ) + + if item.get("type") in TYPE_TO_TEXT_ELEMENT_MAP: + ElementCls = TYPE_TO_TEXT_ELEMENT_MAP[item["type"]] + elements.append(ElementCls(text=item["text"], element_id=element_id, metadata=metadata)) + elif item.get("type") == "CheckBox": + elements.append( + CheckBox(checked=item["checked"], element_id=element_id, metadata=metadata) + ) + + return elements + +def elements_to_base64_gzipped_json(elements: Iterable[Element]) -> str: + """Convert `elements` to Base64-encoded gzipped JSON. + + This is used to when serializing `ElementMetadata.orig_elements` to make it as compact as + possible when transported as JSON, for example in an HTTP response. 
This compressed form is also + present when elements are in dict form ("element_dicts"). This function is not coupled to that + purpose however and could have other uses. + """ + # -- adjust floating-point precision of coordinates down for a more compact str value -- + precision_adjusted_elements = _fix_metadata_field_precision(elements) + # -- serialize elements as dicts -- + element_dicts = elements_to_dicts(precision_adjusted_elements) + # -- serialize the dicts to JSON (bytes) -- + json_bytes = json.dumps(element_dicts, sort_keys=True).encode("utf-8") + # -- compress the JSON bytes with gzip compression -- + deflated_bytes = zlib.compress(json_bytes) + # -- base64-encode those bytes so they can be serialized as a JSON string value -- + b64_deflated_bytes = base64.b64encode(deflated_bytes) + # -- convert to a string suitable for serializing in JSON -- + return b64_deflated_bytes.decode("utf-8") + + +def elements_to_dicts(elements: Iterable[Element]) -> list[dict[str, Any]]: + """Convert document elements to element-dicts.""" + return [e.to_dict() for e in elements] + + +def _fix_metadata_field_precision(elements: Iterable[Element]) -> list[Element]: + out_elements: list[Element] = [] + for element in elements: + el = deepcopy(element) + if el.metadata.coordinates: + precision = 1 if isinstance(el.metadata.coordinates.system, PixelSpace) else 2 + points = el.metadata.coordinates.points + assert points is not None + rounded_points: list[Point] = [] + for point in points: + x, y = point + rounded_point = (round(x, precision), round(y, precision)) + rounded_points.append(rounded_point) + el.metadata.coordinates.points = tuple(rounded_points) + + if el.metadata.detection_class_prob: + el.metadata.detection_class_prob = round(el.metadata.detection_class_prob, 5) + + out_elements.append(el) + + return out_elements \ No newline at end of file diff --git a/code_to_optimize/code_directories/unstructured_example/coordinates.py 
b/code_to_optimize/code_directories/unstructured_example/coordinates.py new file mode 100644 index 000000000..16a65afd3 --- /dev/null +++ b/code_to_optimize/code_directories/unstructured_example/coordinates.py @@ -0,0 +1,113 @@ +from __future__ import annotations + +from enum import Enum +from typing import Any, Dict, Sequence, Tuple, Union + + +class Orientation(Enum): + SCREEN = (1, -1) # Origin in top left, y increases in the down direction + CARTESIAN = (1, 1) # Origin in bottom left, y increases in upward direction + + +def convert_coordinate(old_t, old_t_max, new_t_max, t_orientation): + """Convert a coordinate into another system along an axis using a linear transformation""" + return ( + (1 - old_t / old_t_max) * (1 - t_orientation) / 2 + + old_t / old_t_max * (1 + t_orientation) / 2 + ) * new_t_max + + +class CoordinateSystem: + """A finite coordinate plane with given width and height.""" + + orientation: Orientation + + def __init__(self, width: Union[int, float], height: Union[int, float]): + self.width = width + self.height = height + + def __eq__(self, other: object): + if not isinstance(other, CoordinateSystem): + return False + return ( + str(self.__class__.__name__) == str(other.__class__.__name__) + and self.width == other.width + and self.height == other.height + and self.orientation == other.orientation + ) + + def convert_from_relative( + self, + x: Union[float, int], + y: Union[float, int], + ) -> Tuple[Union[float, int], Union[float, int]]: + """Convert to this coordinate system from a relative coordinate system.""" + x_orientation, y_orientation = self.orientation.value + new_x = convert_coordinate(x, 1, self.width, x_orientation) + new_y = convert_coordinate(y, 1, self.height, y_orientation) + return new_x, new_y + + def convert_to_relative( + self, + x: Union[float, int], + y: Union[float, int], + ) -> Tuple[Union[float, int], Union[float, int]]: + """Convert from this coordinate system to a relative coordinate system.""" + x_orientation, 
y_orientation = self.orientation.value + new_x = convert_coordinate(x, self.width, 1, x_orientation) + new_y = convert_coordinate(y, self.height, 1, y_orientation) + return new_x, new_y + + def convert_coordinates_to_new_system( + self, + new_system: CoordinateSystem, + x: Union[float, int], + y: Union[float, int], + ) -> Tuple[Union[float, int], Union[float, int]]: + """Convert from this coordinate system to another given coordinate system.""" + rel_x, rel_y = self.convert_to_relative(x, y) + return new_system.convert_from_relative(rel_x, rel_y) + + def convert_multiple_coordinates_to_new_system( + self, + new_system: CoordinateSystem, + coordinates: Sequence[Tuple[Union[float, int], Union[float, int]]], + ) -> Tuple[Tuple[Union[float, int], Union[float, int]], ...]: + """Convert (x, y) coordinates from current system to another coordinate system.""" + new_system_coordinates = [] + for x, y in coordinates: + new_system_coordinates.append( + self.convert_coordinates_to_new_system(new_system=new_system, x=x, y=y), + ) + return tuple(new_system_coordinates) + + +class RelativeCoordinateSystem(CoordinateSystem): + """Relative coordinate system where x and y are on a scale from 0 to 1.""" + + orientation = Orientation.CARTESIAN + + def __init__(self): + self.width = 1 + self.height = 1 + + +class PixelSpace(CoordinateSystem): + """Coordinate system representing a pixel space, such as an image. The origin is at the top + left.""" + + orientation = Orientation.SCREEN + + +class PointSpace(CoordinateSystem): + """Coordinate system representing a point space, such as a pdf. 
The origin is at the bottom + left.""" + + orientation = Orientation.CARTESIAN + + +TYPE_TO_COORDINATE_SYSTEM_MAP: Dict[str, Any] = { + "PixelSpace": PixelSpace, + "PointSpace": PointSpace, + "CoordinateSystem": CoordinateSystem, +} \ No newline at end of file diff --git a/code_to_optimize/code_directories/unstructured_example/elements.py b/code_to_optimize/code_directories/unstructured_example/elements.py new file mode 100644 index 000000000..44a1b35ef --- /dev/null +++ b/code_to_optimize/code_directories/unstructured_example/elements.py @@ -0,0 +1,989 @@ +from __future__ import annotations + +import abc +import copy +import dataclasses as dc +import enum +import hashlib +import os +import pathlib +import uuid +from itertools import groupby +from types import MappingProxyType +from typing import Any, Callable, FrozenSet, Optional, Sequence, cast + +from typing_extensions import ParamSpec, TypeAlias, TypedDict + +from coordinates import ( + TYPE_TO_COORDINATE_SYSTEM_MAP, + CoordinateSystem, + RelativeCoordinateSystem, +) + +Point: TypeAlias = "tuple[float, float]" +Points: TypeAlias = "tuple[Point, ...]" + + +@dc.dataclass +class DataSourceMetadata: + """Metadata fields that pertain to the data source of the document.""" + + url: Optional[str] = None + version: Optional[str] = None + record_locator: Optional[dict[str, Any]] = None # Values must be JSON-serializable + date_created: Optional[str] = None + date_modified: Optional[str] = None + date_processed: Optional[str] = None + permissions_data: Optional[list[dict[str, Any]]] = None + + def to_dict(self): + return {key: value for key, value in self.__dict__.items() if value is not None} + + @classmethod + def from_dict(cls, input_dict: dict[str, Any]): + # Only use existing fields when constructing + supported_fields = [f.name for f in dc.fields(cls)] + args = {k: v for k, v in input_dict.items() if k in supported_fields} + + return cls(**args) + + +@dc.dataclass +class CoordinatesMetadata: + """Metadata fields 
that pertain to the coordinates of the element.""" + + points: Optional[Points] + system: Optional[CoordinateSystem] + + def __init__(self, points: Optional[Points], system: Optional[CoordinateSystem]): + # Both `points` and `system` must be present; one is not meaningful without the other. + if (points is None and system is not None) or (points is not None and system is None): + raise ValueError( + "Coordinates points should not exist without coordinates system and vice versa.", + ) + self.points = points + self.system = system + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, CoordinatesMetadata): + return False + return all( + [ + (self.points == other.points), + (self.system == other.system), + ], + ) + + def to_dict(self): + return { + "points": self.points, + "system": None if self.system is None else str(self.system.__class__.__name__), + "layout_width": None if self.system is None else self.system.width, + "layout_height": None if self.system is None else self.system.height, + } + + @classmethod + def from_dict(cls, input_dict: dict[str, Any]): + # `input_dict` may contain a tuple of tuples or a list of lists + def convert_to_points(sequence_of_sequences: Sequence[Sequence[float]]) -> Points: + points: list[Point] = [] + for seq in sequence_of_sequences: + if isinstance(seq, list): + points.append(cast(Point, tuple(seq))) + elif isinstance(seq, tuple): + points.append(cast(Point, seq)) + return tuple(points) + + # -- parse points -- + input_points = input_dict.get("points") + points = convert_to_points(input_points) if input_points is not None else None + + # -- parse system -- + system_name = input_dict.get("system") + width = input_dict.get("layout_width") + height = input_dict.get("layout_height") + system = ( + None + if system_name is None + else ( + RelativeCoordinateSystem() + if system_name == "RelativeCoordinateSystem" + else ( + TYPE_TO_COORDINATE_SYSTEM_MAP[system_name](width, height) + if ( + width is not None + and height is 
not None + and system_name in TYPE_TO_COORDINATE_SYSTEM_MAP + ) + else None + ) + ) + ) + + return cls(points=points, system=system) + + +class Link(TypedDict): + """Metadata related to extracted links""" + + text: Optional[str] + url: str + start_index: int + + +class FormKeyOrValue(TypedDict): + text: str + layout_element_id: Optional[str] + custom_element: Optional[Text] + + +class FormKeyValuePair(TypedDict): + key: FormKeyOrValue + value: Optional[FormKeyOrValue] + confidence: float + + +class ElementMetadata: + """Fully-dynamic replacement for dataclass-based ElementMetadata.""" + + # NOTE(scanny): To add a field: + # - Add the field declaration with type here at the top. This makes it a "known" field and + # enables type-checking and completion. + # - Add a parameter with default for field in __init__() and assign it in __init__() body. + # - Add a consolidation strategy for the new field in + # `ConsolidationStrategy.field_consolidation_strategies()` below. This strategy will be used + # to consolidate this new metadata field from each pre-chunk element during chunking. + # - Add field-name to DEBUG_FIELD_NAMES if it shouldn't appear in dict/JSON or participate in + # equality comparison. 
+ + attached_to_filename: Optional[str] + category_depth: Optional[int] + coordinates: Optional[CoordinatesMetadata] + data_source: Optional[DataSourceMetadata] + # -- Detection Model Class Probabilities from Unstructured-Inference Hi-Res -- + detection_class_prob: Optional[float] + # -- DEBUG field, the detection mechanism that emitted this element -- + detection_origin: Optional[str] + emphasized_text_contents: Optional[list[str]] + emphasized_text_tags: Optional[list[str]] + file_directory: Optional[str] + filename: Optional[str] + filetype: Optional[str] + image_url: Optional[str] + image_path: Optional[str] + image_base64: Optional[str] + image_mime_type: Optional[str] + # -- specific to DOCX which has distinct primary, first-page, and even-page header/footers -- + header_footer_type: Optional[str] + # -- used in chunks only, when chunk must be split mid-text to fit window -- + is_continuation: Optional[bool] + key_value_pairs: Optional[list[FormKeyValuePair]] + languages: Optional[list[str]] + last_modified: Optional[str] + link_texts: Optional[list[str]] + link_urls: Optional[list[str]] + link_start_indexes: Optional[list[int]] + links: Optional[list[Link]] + # -- used in chunks only, allowing access to element(s) chunk was formed from when enabled -- + orig_elements: Optional[list[Element]] + # -- the worksheet name in XLXS documents -- + page_name: Optional[str] + # -- page numbers currently supported for DOCX, HTML, PDF, and PPTX documents -- + page_number: Optional[int] + parent_id: Optional[str] + + # -- e-mail specific metadata fields -- + bcc_recipient: Optional[list[str]] + cc_recipient: Optional[list[str]] + email_message_id: Optional[str] + sent_from: Optional[list[str]] + sent_to: Optional[list[str]] + subject: Optional[str] + signature: Optional[str] + + # -- used for Table elements to capture rows/col structure -- + text_as_html: Optional[str] + table_as_cells: Optional[dict[str, str | int]] + url: Optional[str] + + # -- debug fields can be 
assigned and referenced using dotted-notation but are not serialized + # -- to dict/JSON, do not participate in equality comparison, and are not included in the + # -- `.fields` dict used by other parts of the library like chunking and weaviate. + DEBUG_FIELD_NAMES = frozenset(["detection_origin"]) + + def __init__( + self, + attached_to_filename: Optional[str] = None, + bcc_recipient: Optional[list[str]] = None, + category_depth: Optional[int] = None, + cc_recipient: Optional[list[str]] = None, + coordinates: Optional[CoordinatesMetadata] = None, + data_source: Optional[DataSourceMetadata] = None, + detection_class_prob: Optional[float] = None, + emphasized_text_contents: Optional[list[str]] = None, + emphasized_text_tags: Optional[list[str]] = None, + file_directory: Optional[str] = None, + filename: Optional[str | pathlib.Path] = None, + filetype: Optional[str] = None, + header_footer_type: Optional[str] = None, + image_base64: Optional[str] = None, + image_mime_type: Optional[str] = None, + image_url: Optional[str] = None, + image_path: Optional[str] = None, + is_continuation: Optional[bool] = None, + languages: Optional[list[str]] = None, + last_modified: Optional[str] = None, + link_start_indexes: Optional[list[int]] = None, + link_texts: Optional[list[str]] = None, + link_urls: Optional[list[str]] = None, + links: Optional[list[Link]] = None, + email_message_id: Optional[str] = None, + orig_elements: Optional[list[Element]] = None, + page_name: Optional[str] = None, + page_number: Optional[int] = None, + parent_id: Optional[str] = None, + sent_from: Optional[list[str]] = None, + sent_to: Optional[list[str]] = None, + signature: Optional[str] = None, + subject: Optional[str] = None, + table_as_cells: Optional[dict[str, str | int]] = None, + text_as_html: Optional[str] = None, + url: Optional[str] = None, + ) -> None: + self.attached_to_filename = attached_to_filename + self.bcc_recipient = bcc_recipient + self.category_depth = category_depth + 
self.cc_recipient = cc_recipient + self.coordinates = coordinates + self.data_source = data_source + self.detection_class_prob = detection_class_prob + self.emphasized_text_contents = emphasized_text_contents + self.emphasized_text_tags = emphasized_text_tags + + # -- accommodate pathlib.Path for filename -- + filename = str(filename) if isinstance(filename, pathlib.Path) else filename + # -- produces "", "" when filename arg is None -- + directory_path, file_name = os.path.split(filename or "") + # -- prefer `file_directory` arg if specified, otherwise split of file-path passed as + # -- `filename` arg, or None if `filename` is the empty string. + self.file_directory = file_directory or directory_path or None + self.filename = file_name or None + + self.filetype = filetype + self.header_footer_type = header_footer_type + self.image_base64 = image_base64 + self.image_mime_type = image_mime_type + self.image_url = image_url + self.image_path = image_path + self.is_continuation = is_continuation + self.languages = languages + self.last_modified = last_modified + self.link_texts = link_texts + self.link_urls = link_urls + self.link_start_indexes = link_start_indexes + self.links = links + self.email_message_id = email_message_id + self.orig_elements = orig_elements + self.page_name = page_name + self.page_number = page_number + self.parent_id = parent_id + self.sent_from = sent_from + self.sent_to = sent_to + self.signature = signature + self.subject = subject + self.text_as_html = text_as_html + self.table_as_cells = table_as_cells + self.url = url + + @classmethod + def from_dict(cls, meta_dict: dict[str, Any]) -> ElementMetadata: + """Construct from a metadata-dict. + + This would generally be a dict formed using the `.to_dict()` method and stored as JSON + before "rehydrating" it using this method. 
+ """ + from base import elements_from_base64_gzipped_json + + # -- avoid unexpected mutation by working on a copy of provided dict -- + meta_dict = copy.deepcopy(meta_dict) + self = ElementMetadata() + for field_name, field_value in meta_dict.items(): + if field_name == "coordinates": + self.coordinates = CoordinatesMetadata.from_dict(field_value) + elif field_name == "data_source": + self.data_source = DataSourceMetadata.from_dict(field_value) + elif field_name == "orig_elements": + self.orig_elements = elements_from_base64_gzipped_json(field_value) + elif field_name == "key_value_pairs": + self.key_value_pairs = _kvform_rehydrate_internal_elements(field_value) + else: + setattr(self, field_name, field_value) + + return self + + @property + def fields(self) -> MappingProxyType[str, Any]: + """Populated metadata fields in this object as a read-only dict. + + Basically `self.__dict__` but it needs a little filtering to remove entries like + "_known_field_names". Note this is a *snapshot* and will not reflect later changes. + """ + return MappingProxyType( + { + field_name: field_value + for field_name, field_value in self.__dict__.items() + if not field_name.startswith("_") and field_name not in self.DEBUG_FIELD_NAMES + } + ) + + @property + def known_fields(self) -> MappingProxyType[str, Any]: + """Populated non-ad-hoc fields in this object as a read-only dict. + + Only fields declared at the top of this class are included. Ad-hoc fields added to this + instance by assignment are not. Note this is a *snapshot* and will not reflect changes that + occur after this call. + """ + known_field_names = self._known_field_names + return MappingProxyType( + { + field_name: field_value + for field_name, field_value in self.__dict__.items() + if (field_name in known_field_names and field_name not in self.DEBUG_FIELD_NAMES) + } + ) + + def to_dict(self) -> dict[str, Any]: + """Convert this metadata to dict form, suitable for JSON serialization. 
+ + The returned dict is "sparse" in that no key-value pair appears for a field with value + `None`. + """ + from base import elements_to_base64_gzipped_json + + meta_dict = copy.deepcopy(dict(self.fields)) + + # -- remove fields that should not be serialized -- + for field_name in self.DEBUG_FIELD_NAMES: + meta_dict.pop(field_name, None) + + # -- don't serialize empty lists -- + meta_dict: dict[str, Any] = { + field_name: value + for field_name, value in meta_dict.items() + if value != [] and value != {} + } + + # -- serialize sub-object types when present -- + if self.coordinates is not None: + meta_dict["coordinates"] = self.coordinates.to_dict() + if self.data_source is not None: + meta_dict["data_source"] = self.data_source.to_dict() + if self.orig_elements is not None: + meta_dict["orig_elements"] = elements_to_base64_gzipped_json(self.orig_elements) + if self.key_value_pairs is not None: + meta_dict["key_value_pairs"] = _kvform_pairs_to_dict(self.key_value_pairs) + + return meta_dict + + def update(self, other: ElementMetadata) -> None: + """Update self with all fields present in `other`. + + Semantics are like those of `dict.update()`. + + - fields present in both `self` and `other` will be updated to the value in `other`. + - fields present in `other` but not `self` will be added to `self`. + - fields present in `self` but not `other` are unchanged. + - `other` is unchanged. + - both ad-hoc and known fields participate in update with the same semantics. + + Note that fields listed in DEBUG_FIELD_NAMES are skipped in this process. Those can only be + updated by direct assignment to the instance. 
+ """ + if not isinstance(other, ElementMetadata): # pyright: ignore[reportUnnecessaryIsInstance] + raise ValueError("argument to '.update()' must be an instance of 'ElementMetadata'") + + for field_name, field_value in other.fields.items(): + setattr(self, field_name, field_value) + + def _known_field_names(self) -> FrozenSet[str]: + return frozenset(self.__annotations__) + + +class ConsolidationStrategy(enum.Enum): + """Methods by which a metadata field can be consolidated across a collection of elements. + + These are assigned to `ElementMetadata` field-names immediately below. Metadata consolidation is + part of the chunking process and may arise elsewhere as well. + """ + + DROP = "drop" + """Do not include this field in the consolidated metadata object.""" + + FIRST = "first" + """Use the first value encountered, omit if not present in any elements.""" + + STRING_CONCATENATE = "string_concatenate" + """Combine the values of this field across elements. Only suitable for fields of `str` type.""" + + LIST_CONCATENATE = "LIST_CONCATENATE" + """Concatenate the list values across elements. Only suitable for fields of `List` type.""" + + LIST_UNIQUE = "list_unique" + """Union list values across elements, preserving order. Only suitable for `List` fields.""" + + @classmethod + def field_consolidation_strategies(cls) -> dict[str, ConsolidationStrategy]: + """Mapping from ElementMetadata field-name to its consolidation strategy. + + Note that only _TextSection objects ("pre-chunks" containing only `Text` elements that are + not `Table`) have their metadata consolidated, so these strategies are only applicable for + non-Table Text elements. 
+ """ + return { + "attached_to_filename": cls.FIRST, + "cc_recipient": cls.FIRST, + "bcc_recipient": cls.FIRST, + "category_depth": cls.DROP, + "coordinates": cls.DROP, + "data_source": cls.FIRST, + "detection_class_prob": cls.DROP, + "detection_origin": cls.DROP, + "emphasized_text_contents": cls.LIST_CONCATENATE, + "emphasized_text_tags": cls.LIST_CONCATENATE, + "file_directory": cls.FIRST, + "filename": cls.FIRST, + "filetype": cls.FIRST, + "header_footer_type": cls.DROP, + "image_url": cls.DROP, + "image_path": cls.DROP, + "image_base64": cls.DROP, + "image_mime_type": cls.DROP, + "is_continuation": cls.DROP, # -- not expected, added by chunking, not before -- + "languages": cls.LIST_UNIQUE, + "last_modified": cls.FIRST, + "link_texts": cls.LIST_CONCATENATE, + "link_urls": cls.LIST_CONCATENATE, + "link_start_indexes": cls.DROP, + "links": cls.DROP, # -- deprecated field -- + "email_message_id": cls.FIRST, + "max_characters": cls.DROP, # -- unused, remove from ElementMetadata -- + "orig_elements": cls.DROP, # -- not expected, added by chunking, not before -- + "page_name": cls.FIRST, + "page_number": cls.FIRST, + "parent_id": cls.DROP, + "sent_from": cls.FIRST, + "sent_to": cls.FIRST, + "signature": cls.FIRST, + "subject": cls.FIRST, + "text_as_html": cls.STRING_CONCATENATE, + "table_as_cells": cls.FIRST, # -- only occurs in Table -- + "url": cls.FIRST, + "key_value_pairs": cls.DROP, # -- only occurs in FormKeysValues -- + } + + +_P = ParamSpec("_P") + + +def assign_and_map_hash_ids(elements: list[Element]) -> list[Element]: + """Converts `id` and `parent_id` of elements from UUIDs to hashes. + + This function ensures deterministic IDs by: + 1. Converting each element's UUID into a hash. + 2. Updating the `parent_id` to match the new hash ID of parent elements. + + Args: + elements: A list of Element objects to update. + + Returns: + List of updated Element objects with hashes for `id` and `parent_id`. 
+ """ + # -- generate sequence number for each element on a page -- + page_numbers = [e.metadata.page_number for e in elements] + page_seq_pairs = [ + seq_on_page for _, group in groupby(page_numbers) for seq_on_page, _ in enumerate(group) + ] + + # -- assign hash IDs to elements -- + old_to_new_mapping = { + element.id: element.id_to_hash(seq_on_page_counter) + for element, seq_on_page_counter in zip(elements, page_seq_pairs) + } + + # -- map old parent IDs to new ones -- + for e in elements: + parent_id = e.metadata.parent_id + if not parent_id or parent_id not in old_to_new_mapping: + continue + e.metadata.parent_id = old_to_new_mapping[parent_id] + + return elements + + +class ElementType: + TITLE = "Title" + TEXT = "Text" + UNCATEGORIZED_TEXT = "UncategorizedText" + NARRATIVE_TEXT = "NarrativeText" + BULLETED_TEXT = "BulletedText" + PARAGRAPH = "Paragraph" + ABSTRACT = "Abstract" + THREADING = "Threading" + FORM = "Form" + FIELD_NAME = "Field-Name" + VALUE = "Value" + LINK = "Link" + COMPOSITE_ELEMENT = "CompositeElement" + IMAGE = "Image" + PICTURE = "Picture" + FIGURE_CAPTION = "FigureCaption" + FIGURE = "Figure" + CAPTION = "Caption" + LIST = "List" + LIST_ITEM = "ListItem" + LIST_ITEM_OTHER = "List-item" + CHECKED = "Checked" + UNCHECKED = "Unchecked" + CHECK_BOX_CHECKED = "CheckBoxChecked" + CHECK_BOX_UNCHECKED = "CheckBoxUnchecked" + RADIO_BUTTON_CHECKED = "RadioButtonChecked" + RADIO_BUTTON_UNCHECKED = "RadioButtonUnchecked" + ADDRESS = "Address" + EMAIL_ADDRESS = "EmailAddress" + PAGE_BREAK = "PageBreak" + FORMULA = "Formula" + TABLE = "Table" + HEADER = "Header" + HEADLINE = "Headline" + SUB_HEADLINE = "Subheadline" + PAGE_HEADER = "Page-header" # Title? 
+ SECTION_HEADER = "Section-header" + FOOTER = "Footer" + FOOTNOTE = "Footnote" + PAGE_FOOTER = "Page-footer" + PAGE_NUMBER = "PageNumber" + CODE_SNIPPET = "CodeSnippet" + FORM_KEYS_VALUES = "FormKeysValues" + DOCUMENT_DATA = "DocumentData" + + @classmethod + def to_dict(cls): + """ + Convert class attributes to a dictionary. + + Returns: + dict: A dictionary where keys are attribute names and values are attribute values. + """ + return { + attr: getattr(cls, attr) + for attr in dir(cls) + if not callable(getattr(cls, attr)) and not attr.startswith("__") + } + + +class Element(abc.ABC): + """An element is a semantically-coherent component of a document, often a paragraph. + + There are a few design principles that are followed when creating an element: + 1. It will always have an ID, which by default is a random UUID. + 2. Asking for an ID should always return a string, it can never be None. + 3. ID is lazy, meaning it will be generated when asked for the first time. + 4. When deterministic behavior is needed, the ID can be converted. + to a hash based on its text `element.id_to_hash(position)` + 4. Even if the `text` attribute is not defined in a subclass, it will default to a blank string. + 6. Assigning a string ID manually is possible, but is meant to be used + only for deserialization purposes. 
+ """ + + text: str + category = "UncategorizedText" + + def __init__( + self, + element_id: Optional[str] = None, + coordinates: Optional[tuple[tuple[float, float], ...]] = None, + coordinate_system: Optional[CoordinateSystem] = None, + metadata: Optional[ElementMetadata] = None, + detection_origin: Optional[str] = None, + ): + if element_id is not None and not isinstance(element_id, str): # type: ignore + raise ValueError("element_id must be of type str or None.") + + self._element_id = element_id + self.metadata = ElementMetadata() if metadata is None else metadata + if coordinates is not None or coordinate_system is not None: + self.metadata.coordinates = CoordinatesMetadata( + points=coordinates, system=coordinate_system + ) + self.metadata.detection_origin = detection_origin + # -- all `Element` instances get a `text` attribute, defaults to the empty string if not + # -- defined in a subclass. + self.text = self.text if hasattr(self, "text") else "" + + def __str__(self): + return self.text + + def convert_coordinates_to_new_system( + self, new_system: CoordinateSystem, in_place: bool = True + ) -> Optional[Points]: + """Converts the element location coordinates to a new coordinate system. + + If inplace is true, changes the coordinates in place and updates the coordinate system. + """ + if ( + self.metadata.coordinates is None + or self.metadata.coordinates.system is None + or self.metadata.coordinates.points is None + ): + return None + + new_coordinates = tuple( + self.metadata.coordinates.system.convert_coordinates_to_new_system( + new_system=new_system, + x=x, + y=y, + ) + for x, y in self.metadata.coordinates.points + ) + + if in_place: + self.metadata.coordinates.points = new_coordinates + self.metadata.coordinates.system = new_system + + return new_coordinates + + def id_to_hash(self, sequence_number: int) -> str: + """Calculates and assigns a deterministic hash as an ID. 
+ + The hash ID is based on element's text, sequence number on page, + page number and its filename. + + Args: + sequence_number: index on page + + Returns: new ID value + """ + data = f"{self.metadata.filename}{self.text}{self.metadata.page_number}{sequence_number}" + self._element_id = hashlib.sha256(data.encode()).hexdigest()[:32] + return self.id + + @property + def id(self): + if self._element_id is None: + self._element_id = str(uuid.uuid4()) + return self._element_id + + def to_dict(self) -> dict[str, Any]: + return { + "type": None, + "element_id": self.id, + "text": self.text, + "metadata": self.metadata.to_dict(), + } + + +class CheckBox(Element): + """A checkbox with an attribute indicating whether its checked or not. + + Primarily used in documents that are forms. + """ + + def __init__( + self, + element_id: Optional[str] = None, + coordinates: Optional[tuple[tuple[float, float], ...]] = None, + coordinate_system: Optional[CoordinateSystem] = None, + checked: bool = False, + metadata: Optional[ElementMetadata] = None, + detection_origin: Optional[str] = None, + ): + metadata = metadata if metadata else ElementMetadata() + super().__init__( + element_id=element_id, + coordinates=coordinates, + coordinate_system=coordinate_system, + metadata=metadata, + detection_origin=detection_origin, + ) + self.checked: bool = checked + + def __eq__(self, other: object) -> bool: + if not isinstance(other, CheckBox): + return False + return all( + ( + self.checked == other.checked, + self.metadata.coordinates == other.metadata.coordinates, + ) + ) + + def to_dict(self) -> dict[str, Any]: + """Serialize to JSON-compatible (str keys) dict.""" + out = super().to_dict() + out["type"] = "CheckBox" + out["checked"] = self.checked + out["element_id"] = self.id + return out + + +class Text(Element): + """Base element for capturing free text from within document.""" + + def __init__( + self, + text: str, + element_id: Optional[str] = None, + coordinates: 
Optional[tuple[tuple[float, float], ...]] = None, + coordinate_system: Optional[CoordinateSystem] = None, + metadata: Optional[ElementMetadata] = None, + detection_origin: Optional[str] = None, + embeddings: Optional[list[float]] = None, + ): + metadata = metadata if metadata else ElementMetadata() + self.text: str = text + self.embeddings: Optional[list[float]] = embeddings + + super().__init__( + element_id=element_id, + metadata=metadata, + coordinates=coordinates, + coordinate_system=coordinate_system, + detection_origin=detection_origin, + ) + + def __eq__(self, other: object): + if not isinstance(other, Text): + return False + return all( + ( + self.text == other.text, + self.metadata.coordinates == other.metadata.coordinates, + self.category == other.category, + self.embeddings == other.embeddings, + ), + ) + + def __str__(self): + return self.text + + def apply(self, *cleaners: Callable[[str], str]): + """Applies a cleaning brick to the text element. + + The function that's passed in should take a string as input and produce a string as + output. + """ + cleaned_text = self.text + for cleaner in cleaners: + cleaned_text = cleaner(cleaned_text) + + if not isinstance(cleaned_text, str): # pyright: ignore[reportUnnecessaryIsInstance] + raise ValueError("Cleaner produced a non-string output.") + + self.text = cleaned_text + + def to_dict(self) -> dict[str, Any]: + """Serialize to JSON-compatible (str keys) dict.""" + out = super().to_dict() + out["element_id"] = self.id + out["type"] = self.category + out["text"] = self.text + if self.embeddings: + out["embeddings"] = self.embeddings + return out + + +class Formula(Text): + "An element containing formulas in a document" + + category = "Formula" + + +class CompositeElement(Text): + """A chunk formed from text (non-Table) elements. + + Only produced by chunking. An instance may be formed by combining one or more sequential + elements produced by partitioning. 
It is also used when text-splitting an "oversized" element, + a single element that by itself is larger than the requested chunk size. + """ + + category = "CompositeElement" + + +class FigureCaption(Text): + """An element for capturing text associated with figure captions.""" + + category = "FigureCaption" + + +class NarrativeText(Text): + """NarrativeText is an element consisting of multiple, well-formulated sentences. This + excludes elements such as titles, headers, footers, and captions.""" + + category = "NarrativeText" + + +class ListItem(Text): + """ListItem is a NarrativeText element that is part of a list.""" + + category = "ListItem" + + +class Title(Text): + """A text element for capturing titles.""" + + category = "Title" + + +class Address(Text): + """A text element for capturing addresses.""" + + category = "Address" + + +class EmailAddress(Text): + """A text element for capturing email addresses.""" + + category = "EmailAddress" + + +class Image(Text): + """A text element for capturing image metadata.""" + + category = ElementType.IMAGE + + +class PageBreak(Text): + """An element for capturing page breaks.""" + + category = "PageBreak" + + +class Table(Text): + """An element for capturing tables.""" + + category = "Table" + + +class TableChunk(Table): + """An element for capturing chunks of tables.""" + + category = "Table" + + +class Header(Text): + """An element for capturing document headers.""" + + category = "Header" + + +class Footer(Text): + """An element for capturing document footers.""" + + category = "Footer" + + +class CodeSnippet(Text): + """An element for capturing code snippets.""" + + category = "CodeSnippet" + + +class PageNumber(Text): + """An element for capturing page numbers.""" + + category = "PageNumber" + + +class FormKeysValues(Text): + """An element for capturing Key-Value dicts (forms).""" + + category = "FormKeysValues" + + +class DocumentData(Text): + """An element for capturing document-level data, + particularly for large data 
that does not make sense to + represent across each element in the document.""" + + category = "DocumentData" + + +TYPE_TO_TEXT_ELEMENT_MAP: dict[str, type[Text]] = { + ElementType.TITLE: Title, + ElementType.SECTION_HEADER: Title, + ElementType.HEADLINE: Title, + ElementType.SUB_HEADLINE: Title, + ElementType.FIELD_NAME: Title, + ElementType.UNCATEGORIZED_TEXT: Text, + ElementType.COMPOSITE_ELEMENT: CompositeElement, + ElementType.TEXT: NarrativeText, + ElementType.NARRATIVE_TEXT: NarrativeText, + ElementType.PARAGRAPH: NarrativeText, + # this mapping ensures yolox produces backward compatible categories + ElementType.ABSTRACT: NarrativeText, + ElementType.THREADING: NarrativeText, + ElementType.FORM: NarrativeText, + ElementType.VALUE: NarrativeText, + ElementType.LINK: NarrativeText, + ElementType.LIST_ITEM: ListItem, + ElementType.BULLETED_TEXT: ListItem, + ElementType.LIST_ITEM_OTHER: ListItem, + ElementType.HEADER: Header, + ElementType.PAGE_HEADER: Header, # Title? + ElementType.FOOTER: Footer, + ElementType.PAGE_FOOTER: Footer, + ElementType.FOOTNOTE: Footer, + ElementType.FIGURE_CAPTION: FigureCaption, + ElementType.CAPTION: FigureCaption, + ElementType.IMAGE: Image, + ElementType.FIGURE: Image, + ElementType.PICTURE: Image, + ElementType.TABLE: Table, + ElementType.ADDRESS: Address, + ElementType.EMAIL_ADDRESS: EmailAddress, + ElementType.FORMULA: Formula, + ElementType.PAGE_BREAK: PageBreak, + ElementType.CODE_SNIPPET: CodeSnippet, + ElementType.PAGE_NUMBER: PageNumber, + ElementType.FORM_KEYS_VALUES: FormKeysValues, + ElementType.DOCUMENT_DATA: DocumentData, +} + + +def _kvform_rehydrate_internal_elements(kv_pairs: list[dict[str, Any]]) -> list[FormKeyValuePair]: + """ + The key_value_pairs metadata field contains (in the vast majority of cases) + nested Text elements. Those need to be turned from dicts into Elements explicitly, + e.g. when partition_json is used. 
+ """ + from base import elements_from_dicts + + # safe to overwrite - deepcopy already happened + for kv_pair in kv_pairs: + if kv_pair["key"]["custom_element"] is not None: + (kv_pair["key"]["custom_element"],) = elements_from_dicts( + [kv_pair["key"]["custom_element"]] + ) + if kv_pair["value"] is not None and kv_pair["value"]["custom_element"] is not None: + (kv_pair["value"]["custom_element"],) = elements_from_dicts( + [kv_pair["value"]["custom_element"]] + ) + return cast(list[FormKeyValuePair], kv_pairs) + + +def _kvform_pairs_to_dict(orig_kv_pairs: list[FormKeyValuePair]) -> list[dict[str, Any]]: + """ + The key_value_pairs metadata field contains (in the vast majority of cases) + nested Text elements. Those need to be turned from Elements to dicts recursively, + e.g. when FormKeysValues.to_dict() is used. + + """ + kv_pairs: list[dict[str, Any]] = copy.deepcopy(orig_kv_pairs) # type: ignore + for kv_pair in kv_pairs: + if kv_pair["key"]["custom_element"] is not None: + kv_pair["key"]["custom_element"] = kv_pair["key"]["custom_element"].to_dict() + if kv_pair["value"] is not None and kv_pair["value"]["custom_element"] is not None: + kv_pair["value"]["custom_element"] = kv_pair["value"]["custom_element"].to_dict() + + return kv_pairs \ No newline at end of file diff --git a/code_to_optimize/code_directories/unstructured_example/optimized.py b/code_to_optimize/code_directories/unstructured_example/optimized.py new file mode 100644 index 000000000..dab1bf23b --- /dev/null +++ b/code_to_optimize/code_directories/unstructured_example/optimized.py @@ -0,0 +1,165 @@ +from __future__ import annotations +import os +import pathlib +from typing import Any, Iterable, Optional + +from elements import (TYPE_TO_TEXT_ELEMENT_MAP, CheckBox, + Element) +from elements import ElementMetadata as _ElementMetadata + + +# Helper to resolve 'pathlib.Path' on filename efficiently +def _extract_file_directory_and_name(filename: Optional[str | pathlib.Path], file_directory: 
Optional[str]) -> tuple[Optional[str], Optional[str]]: + if isinstance(filename, pathlib.Path): + filename = str(filename) + directory_path, file_name = os.path.split(filename or "") + return (file_directory or directory_path or None, file_name or None) + +class ElementMetadata: + def __init__( + self, + attached_to_filename: Optional[str] = None, + bcc_recipient: Optional[list[str]] = None, + category_depth: Optional[int] = None, + cc_recipient: Optional[list[str]] = None, + coordinates: Optional[Any] = None, + data_source: Optional[Any] = None, + detection_class_prob: Optional[float] = None, + emphasized_text_contents: Optional[list[str]] = None, + emphasized_text_tags: Optional[list[str]] = None, + file_directory: Optional[str] = None, + filename: Optional[str | pathlib.Path] = None, + filetype: Optional[str] = None, + header_footer_type: Optional[str] = None, + image_base64: Optional[str] = None, + image_mime_type: Optional[str] = None, + image_url: Optional[str] = None, + image_path: Optional[str] = None, + is_continuation: Optional[bool] = None, + languages: Optional[list[str]] = None, + last_modified: Optional[str] = None, + link_start_indexes: Optional[list[int]] = None, + link_texts: Optional[list[str]] = None, + link_urls: Optional[list[str]] = None, + links: Optional[list[Any]] = None, + email_message_id: Optional[str] = None, + orig_elements: Optional[list[Element]] = None, + page_name: Optional[str] = None, + page_number: Optional[int] = None, + parent_id: Optional[str] = None, + sent_from: Optional[list[str]] = None, + sent_to: Optional[list[str]] = None, + signature: Optional[str] = None, + subject: Optional[str] = None, + table_as_cells: Optional[dict[str, str | int]] = None, + text_as_html: Optional[str] = None, + url: Optional[str] = None, + key_value_pairs: Optional[list[Any]] = None, + ) -> None: + self.attached_to_filename = attached_to_filename + self.bcc_recipient = bcc_recipient + self.category_depth = category_depth + self.cc_recipient = 
cc_recipient + self.coordinates = coordinates + self.data_source = data_source + self.detection_class_prob = detection_class_prob + self.emphasized_text_contents = emphasized_text_contents + self.emphasized_text_tags = emphasized_text_tags + + # -- accommodate pathlib.Path for filename -- + self.file_directory, self.filename = _extract_file_directory_and_name(filename, file_directory) + self.filetype = filetype + self.header_footer_type = header_footer_type + self.image_base64 = image_base64 + self.image_mime_type = image_mime_type + self.image_url = image_url + self.image_path = image_path + self.is_continuation = is_continuation + self.languages = languages + self.last_modified = last_modified + self.link_texts = link_texts + self.link_urls = link_urls + self.link_start_indexes = link_start_indexes + self.links = links + self.email_message_id = email_message_id + self.orig_elements = orig_elements + self.page_name = page_name + self.page_number = page_number + self.parent_id = parent_id + self.sent_from = sent_from + self.sent_to = sent_to + self.signature = signature + self.subject = subject + self.text_as_html = text_as_html + self.table_as_cells = table_as_cells + self.url = url + self.key_value_pairs = key_value_pairs + + @classmethod + def from_dict(cls, meta_dict: dict[str, Any]) -> 'ElementMetadata': + """Construct from a metadata-dict. + + This would generally be a dict formed using the `.to_dict()` method and stored as JSON + before "rehydrating" it using this method. 
+ """ + from base import elements_from_base64_gzipped_json + + # Rather than copy.deepcopy, build fast new fields dict + key_value = meta_dict.get + + # Local import avoids import cycles (as originally intended) + coords_val = key_value("coordinates") + coordinates = CoordinatesMetadata.from_dict(coords_val) if coords_val is not None else None + + data_source_val = key_value("data_source") + data_source = DataSourceMetadata.from_dict(data_source_val) if data_source_val is not None else None + + orig_elements_val = key_value("orig_elements") + orig_elements = ( + elements_from_base64_gzipped_json(orig_elements_val) if orig_elements_val is not None else None + ) + + key_value_pairs_val = key_value("key_value_pairs") + key_value_pairs = ( + _kvform_rehydrate_internal_elements(key_value_pairs_val) if key_value_pairs_val is not None else None + ) + + # Build argument dict for __init__ using known args and self assignment for remaining + # Fast field assignment - all remaining fields + args = { + k: v for k, v in meta_dict.items() + if k not in ("coordinates", "data_source", "orig_elements", "key_value_pairs") + } + args["coordinates"] = coordinates + args["data_source"] = data_source + args["orig_elements"] = orig_elements + args["key_value_pairs"] = key_value_pairs + + return cls(**args) + +def elements_from_dicts(element_dicts: Iterable[dict[str, Any]]) -> list[Element]: + """Convert a list of element-dicts to a list of elements.""" + # Localize references for speed + ETM_MAP = TYPE_TO_TEXT_ELEMENT_MAP + result_append = [] + result_append_method = result_append.append # avoid attribute lookup inside loop + CheckBoxClass = CheckBox + + for item in element_dicts: + itype = item.get("type") + element_id = item.get("element_id") + meta = item.get("metadata") + metadata = ElementMetadata() if meta is None else ElementMetadata.from_dict(meta) + if itype in ETM_MAP: + ElementCls = ETM_MAP[itype] + result_append_method(ElementCls( + text=item["text"], element_id=element_id, 
metadata=metadata + )) + elif itype == "CheckBox": + result_append_method( + CheckBoxClass( + checked=item["checked"], element_id=element_id, metadata=metadata + ) + ) + return result_append + diff --git a/code_to_optimize/code_directories/unstructured_example/pyproject.toml b/code_to_optimize/code_directories/unstructured_example/pyproject.toml new file mode 100644 index 000000000..bddef0ed3 --- /dev/null +++ b/code_to_optimize/code_directories/unstructured_example/pyproject.toml @@ -0,0 +1,7 @@ +[tool.codeflash] +# All paths are relative to this pyproject.toml's directory. +module-root = "." +tests-root = "tests" +test-framework = "pytest" +ignore-paths = [] +formatter-cmds = ["black $file"] diff --git a/code_to_optimize/code_directories/unstructured_example/utils.py b/code_to_optimize/code_directories/unstructured_example/utils.py new file mode 100644 index 000000000..1610a1968 --- /dev/null +++ b/code_to_optimize/code_directories/unstructured_example/utils.py @@ -0,0 +1,8 @@ +from typing import ( + Tuple, +) +from typing_extensions import TypeAlias + +Box: TypeAlias = Tuple[float, float, float, float] +Point: TypeAlias = Tuple[float, float] +Points: TypeAlias = Tuple[Point, ...] 
\ No newline at end of file diff --git a/codeflash/code_utils/code_replacer.py b/codeflash/code_utils/code_replacer.py index 3c73c5919..31b9f9465 100644 --- a/codeflash/code_utils/code_replacer.py +++ b/codeflash/code_utils/code_replacer.py @@ -354,11 +354,68 @@ def leave_Module(self, original_node: cst.Module, updated_node: cst.Module) -> c return node +class FunctionUsageCollector(cst.CSTVisitor): + def __init__(self) -> None: + self.used_functions: set[str] = set() + + def visit_Call(self, node: cst.Call) -> bool: + if isinstance(node.func, cst.Name): + self.used_functions.add(node.func.value) + return True + + +class UnusedFunctionRemover(cst.CSTTransformer): + def __init__(self, unused_function_names: set[str]) -> None: + self.unused_function_names = unused_function_names + + def leave_FunctionDef( + self, original_node: cst.FunctionDef, updated_node: cst.FunctionDef + ) -> cst.RemovalSentinel | cst.FunctionDef: + if original_node.name.value in self.unused_function_names: + return cst.RemoveFromParent() + return updated_node + + +class ImportTracker(cst.CSTVisitor): + def __init__(self) -> None: + # Set of fully-qualified imports like "utils.helpers.extract_path" + self.imported_names: set[str] = set() + + def get_full_name(self, expr: cst.BaseExpression) -> str: + if isinstance(expr, cst.Name): + return expr.value + if isinstance(expr, cst.Attribute): + return f"{self.get_full_name(expr.value)}.{expr.attr.value}" + return "" + + def visit_ImportFrom(self, node: cst.ImportFrom) -> None: + # Get the full module path like "utils.helpers" + module = ("." 
* node.relative if node.relative else "") + ( + node.module.attr.value + if isinstance(node.module, cst.Attribute) + else node.module.value + if node.module + else "" + ) + + for name in node.names: + if isinstance(name, cst.ImportAlias): + imported_name = name.name.value + self.imported_names.add(f"{module}.{imported_name}") + + def visit_Import(self, node: cst.Import) -> None: + for name in node.names: + if isinstance(name, cst.ImportAlias): + full_module_path = self.get_full_name(name.name) + self.imported_names.add(full_module_path) + + def replace_functions_in_file( source_code: str, original_function_names: list[str], optimized_code: str, preexisting_objects: set[tuple[str, tuple[FunctionParent, ...]]], + module_abspath: Optional[Path] = None, ) -> str: parsed_function_names = [] for original_function_name in original_function_names: @@ -386,7 +443,49 @@ ) original_module = cst.parse_module(source_code) modified_tree = original_module.visit(transformer) - return modified_tree.code + + module_name = module_abspath.name.split(".")[0] if module_abspath is not None else None + + if module_name is None: + return modified_tree.code + + unused_new_function_names = [] + + file_uses = FunctionUsageCollector() + modified_tree.visit(file_uses) + + new_functions = visitor.new_functions + used_functions = file_uses.used_functions + + import_tracker = ImportTracker() + module.visit(import_tracker) + + # get unused new functions (not in used_functions && not in the imported from module_abs.new_function) + for fn in new_functions: + fn_name = fn.name.value + is_imported_from_this_module = f"{module_name}.{fn_name}" in import_tracker.imported_names + if is_imported_from_this_module: + # keep the new function in the module as it's imported elsewhere + continue + + # note: can_be_imported not just means it's imported in the optimized context but also it's imported from other module because of the condition above + can_be_imported = False + for 
_import in import_tracker.imported_names: + if _import.endswith("." + fn_name): + can_be_imported = True + break + + used_and_can_be_imported = fn_name in used_functions and can_be_imported + not_used_and_no_import_for_it = fn_name not in used_functions and not can_be_imported + + if used_and_can_be_imported or not_used_and_no_import_for_it: + # we are then going to remove it + unused_new_function_names.append(fn_name) + + remover = UnusedFunctionRemover(unused_new_function_names) + cleaned_tree = modified_tree.visit(remover) + + return cleaned_tree.code def replace_functions_and_add_imports( @@ -399,7 +498,7 @@ ) -> str: return add_needed_imports_from_module( optimized_code, - replace_functions_in_file(source_code, function_names, optimized_code, preexisting_objects), + replace_functions_in_file(source_code, function_names, optimized_code, preexisting_objects, module_abspath), module_abspath, module_abspath, project_root_path, diff --git a/codeflash/lsp/server.py b/codeflash/lsp/server.py index 1a303c13c..a018786a8 100644 --- a/codeflash/lsp/server.py +++ b/codeflash/lsp/server.py @@ -3,7 +3,7 @@ from pathlib import Path from typing import TYPE_CHECKING, Any -from lsprotocol.types import INITIALIZE, MessageType, LogMessageParams +from lsprotocol.types import INITIALIZE, LogMessageParams, MessageType from pygls import uris from pygls.protocol import LanguageServerProtocol, lsp_method from pygls.server import LanguageServer @@ -58,21 +58,22 @@ def initialize_optimizer(self, config_file: Path) -> None: def show_message_log(self, message: str, message_type: str) -> None: """Send a log message to the client's output channel. 
- + Args: message: The message to log message_type: String type - "Info", "Warning", "Error", or "Log" + """ # Convert string message type to LSP MessageType enum type_mapping = { "Info": MessageType.Info, - "Warning": MessageType.Warning, + "Warning": MessageType.Warning, "Error": MessageType.Error, - "Log": MessageType.Log + "Log": MessageType.Log, } - + lsp_message_type = type_mapping.get(message_type, MessageType.Info) - + # Send log message to client (appears in output channel) log_params = LogMessageParams(type=lsp_message_type, message=message) self.lsp.notify("window/logMessage", log_params) diff --git a/codeflash/lsp/server_entry.py b/codeflash/lsp/server_entry.py index 841d18f84..bae00bcb8 100644 --- a/codeflash/lsp/server_entry.py +++ b/codeflash/lsp/server_entry.py @@ -1,10 +1,9 @@ -"""This script is the dedicated entry point for the Codeflash Language Server. -It initializes the server and redirects its logs to stderr so that the -VS Code client can display them in the output channel. +# This script is the dedicated entry point for the Codeflash Language Server. +# It initializes the server and redirects its logs to stderr so that the +# VS Code client can display them in the output channel. -This script is run by the VS Code extension and is not intended to be -executed directly by users. -""" +# This script is run by the VS Code extension and is not intended to be +# executed directly by users. 
import logging import sys @@ -13,7 +12,7 @@ # Configure logging to stderr for VS Code output channel -def setup_logging(): +def setup_logging() -> logging.Logger: # Clear any existing handlers to prevent conflicts root_logger = logging.getLogger() root_logger.handlers.clear() diff --git a/codeflash/optimization/function_optimizer.py b/codeflash/optimization/function_optimizer.py index 82cf4bc57..3a1cfffa2 100644 --- a/codeflash/optimization/function_optimizer.py +++ b/codeflash/optimization/function_optimizer.py @@ -585,6 +585,8 @@ def write_code_and_helpers(original_code: str, original_helper_code: dict[Path, with path.open("w", encoding="utf8") as f: f.write(original_code) for module_abspath, helper_code in original_helper_code.items(): + if module_abspath == path: + continue # no need to write it again with Path(module_abspath).open("w", encoding="utf8") as f: f.write(helper_code) @@ -632,7 +634,6 @@ def replace_function_and_helpers_with_optimized_code( project_root_path=self.project_root, ) unused_helpers = detect_unused_helper_functions(self.function_to_optimize, code_context, optimized_code) - # Revert unused helper functions to their original definitions if unused_helpers: revert_unused_helper_functions(self.project_root, unused_helpers, original_helper_code) diff --git a/tests/test_code_replacement.py b/tests/test_code_replacement.py index 7272163d3..f41cb7f69 100644 --- a/tests/test_code_replacement.py +++ b/tests/test_code_replacement.py @@ -88,6 +88,7 @@ class NewClass: def __init__(self, name): self.name = name def new_function(self, value): + totally_new_function(123) return self.name def new_function2(value): return value @@ -106,6 +107,7 @@ def new_function(self, value): def __init__(self, name): self.name = name def new_function(self, value): + totally_new_function(123) return self.name def new_function2(value): return value @@ -143,6 +145,7 @@ class NewClass: def __init__(self, name): self.name = name def new_function(self, value): + 
totally_new_function(123) return other_function(self.name) def new_function2(value): return value @@ -164,6 +167,7 @@ class NewClass: def __init__(self, name): self.name = name def new_function(self, value): + totally_new_function(123) return other_function(self.name) def new_function2(value): return value @@ -198,6 +202,7 @@ def totally_new_function(value): return value def other_function(st): + totally_new_function(123) return(st * 2) class NewClass: @@ -230,6 +235,7 @@ def yet_another_function(values): return len(values) def other_function(st): + totally_new_function(123) return(st * 2) def totally_new_function(value): @@ -259,6 +265,7 @@ def totally_new_function(value): return value def yet_another_function(values: Optional[str]): + totally_new_function(123) return len(values) + 2 def other_function(st): @@ -291,6 +298,7 @@ def other_function(st): print("Au revoir") def yet_another_function(values): + totally_new_function(123) return len(values) + 2 def other_function(st): @@ -730,6 +738,7 @@ def __init__(self, name): def __call__(self, value): return self.name def new_function2(value): + totally_new_function() return cst.ensure_type(value, str) """ @@ -750,6 +759,7 @@ def __init__(self, name): def __call__(self, value): return "I am still old" def new_function2(value): + totally_new_function() return cst.ensure_type(value, str) def totally_new_function(value: Optional[str]): @@ -3070,3 +3080,77 @@ def my_fixture(request): modified_module = module.visit(transformer) assert modified_module.code.strip() == expected.strip() + + +def test_new_global_created_helper_functions_scope(): + path_to_root = Path(__file__).resolve().parent.parent / "code_to_optimize" / "code_directories" / "unstructured_example" + optimized_code = (path_to_root / "optimized.py").read_text(encoding="utf-8") + + going_to_unlink=[] + code_path = (path_to_root / "base.py").resolve() + + temp_file = (path_to_root / "base_optimized.py") + 
temp_file.write_text(code_path.read_text(encoding="utf-8"), encoding="utf-8") + + going_to_unlink.append(temp_file) + tests_root = path_to_root / "tests" + + func = FunctionToOptimize(function_name="elements_from_dicts", parents=[], file_path=temp_file.resolve()) + test_config = TestConfig( + tests_root=tests_root, + tests_project_rootdir=tests_root, + project_root_path=path_to_root, + test_framework="pytest", + pytest_cmd="pytest", + ) + func_optimizer = FunctionOptimizer(function_to_optimize=func, test_cfg=test_config) + code_context: CodeOptimizationContext = func_optimizer.get_code_optimization_context().unwrap() + original_helper_code: dict[Path, str] = {} + + # helper_function_paths = {hf.file_path for hf in code_context.helper_functions} + new_helper_functions: list[FakeFunctionSource] = [] + + for hf in code_context.helper_functions: + file_name = hf.file_path.name + temp_helper_file_path = str(hf.file_path).replace(file_name, f"temp_{file_name}") + Path(temp_helper_file_path).write_text(hf.file_path.read_text(encoding="utf-8"), encoding="utf-8") + going_to_unlink.append(Path(temp_helper_file_path)) + new_helper_functions.append(FakeFunctionSource( + file_path= Path(temp_helper_file_path), + fully_qualified_name= hf.fully_qualified_name, + jedi_definition= hf.jedi_definition, + only_function_name= hf.only_function_name, + source_code= hf.source_code, + qualified_name= hf.qualified_name + )) + + code_context.helper_functions = new_helper_functions + helper_function_paths = {hf.file_path for hf in code_context.helper_functions} + + for helper_function_path in helper_function_paths: + with helper_function_path.open(encoding="utf8") as f: + helper_code = f.read() + original_helper_code[helper_function_path] = helper_code + + func_optimizer.replace_function_and_helpers_with_optimized_code( + code_context=code_context, optimized_code=optimized_code, original_helper_code=original_helper_code + ) + final_output = temp_file.read_text(encoding="utf-8") + 
helper_elements_output = (path_to_root / "temp_elements.py").read_text(encoding="utf-8") + + # test rolling back changes + # func_optimizer.write_code_and_helpers( + # func_optimizer.function_to_optimize_source_code, + # original_helper_code, + # func_optimizer.function_to_optimize.file_path, + # ) + # assert code_path.read_text(encoding="utf-8") == temp_file.read_text(encoding="utf-8") + # TODO: assert no changes in the helpers also + + for temp_file in going_to_unlink: + temp_file.unlink(missing_ok=True) + + assert "def _extract_file_directory_and_name" not in final_output + assert "def _extract_file_directory_and_name" in helper_elements_output + +