diff --git a/src/codemodder/codeql.py b/src/codemodder/codeql.py index 711a00ff..80c9c03c 100644 --- a/src/codemodder/codeql.py +++ b/src/codemodder/codeql.py @@ -17,7 +17,8 @@ def detect(cls, run_data: Run) -> bool: class CodeQLLocation(SarifLocation): @classmethod - def from_sarif(cls, sarif_location: LocationModel) -> Self: + def from_sarif(cls, run: Run, sarif_location: LocationModel) -> Self: + del run if (physical_location := sarif_location.physical_location) is None: raise ValueError("Location does not contain a physicalLocation") if (artifact_location := physical_location.artifact_location) is None: diff --git a/src/codemodder/result.py b/src/codemodder/result.py index 023683ec..e6b473de 100644 --- a/src/codemodder/result.py +++ b/src/codemodder/result.py @@ -3,7 +3,7 @@ import itertools from dataclasses import dataclass, field from pathlib import Path -from typing import TYPE_CHECKING, Any, ClassVar, Sequence, Type +from typing import TYPE_CHECKING, Any, ClassVar, Sequence, Type, TypeVar import libcst as cst from boltons.setutils import IndexedSet @@ -41,8 +41,13 @@ class SarifLocation(Location): def get_snippet(sarif_location: LocationModel) -> str | None: return sarif_location.message.text if sarif_location.message else None + @staticmethod + def process_uri(run: Run, sarif_location: LocationModel, uri: str) -> Path: + del sarif_location + return Path(uri) + @classmethod - def from_sarif(cls, sarif_location: LocationModel) -> Self: + def from_sarif(cls, run: Run, sarif_location: LocationModel) -> Self: if not (physical_location := sarif_location.physical_location): raise ValueError("Sarif location does not have a physical location") if not (artifact_location := physical_location.artifact_location): @@ -52,7 +57,7 @@ def from_sarif(cls, sarif_location: LocationModel) -> Self: if not (uri := artifact_location.uri): raise ValueError("Sarif location does not have a uri") - file = Path(uri) + file = cls.process_uri(run, sarif_location, uri) snippet = cls.get_snippet(sarif_location) start = LineInfo( line=region.start_line or -1, @@ -117,9 +122,9 @@ def from_sarif( finding_id = cls.extract_finding_id(sarif_result) or rule_id return cls( rule_id=rule_id, - locations=cls.extract_locations(sarif_result), - codeflows=cls.extract_code_flows(sarif_result), - related_locations=cls.extract_related_locations(sarif_result), + locations=cls.extract_locations(sarif_result, sarif_run), + codeflows=cls.extract_code_flows(sarif_result, sarif_run), + related_locations=cls.extract_related_locations(sarif_result, sarif_run), finding_id=finding_id, finding=Finding( id=finding_id, @@ -147,23 +152,25 @@ def rule_url_from_id( return None @classmethod - def extract_locations(cls, sarif_result: ResultModel) -> Sequence[Location]: + def extract_locations( + cls, sarif_result: ResultModel, run: Run + ) -> Sequence[Location]: return tuple( [ - cls.location_type.from_sarif(location) + cls.location_type.from_sarif(run, location) for location in sarif_result.locations or [] ] ) @classmethod def extract_related_locations( - cls, sarif_result: ResultModel + cls, sarif_result: ResultModel, run: Run ) -> Sequence[LocationWithMessage]: return tuple( [ LocationWithMessage( message=rel_location.message.text, - location=cls.location_type.from_sarif(rel_location), + location=cls.location_type.from_sarif(run, rel_location), ) for rel_location in sarif_result.related_locations or [] if rel_location.message @@ -172,13 +179,13 @@ def extract_related_locations( @classmethod def extract_code_flows( - cls, sarif_result: ResultModel + cls, sarif_result: ResultModel, run: Run ) -> Sequence[Sequence[Location]]: return tuple( [ tuple( [ - cls.location_type.from_sarif(locations.location) + cls.location_type.from_sarif(run, locations.location) for locations in threadflow.locations or [] if locations.location ] @@ -225,8 +232,11 @@ def fuzzy_column_match(pos: CodeRange, location: Location) -> bool: ) -class ResultSet(dict[str, dict[Path, list[Result]]]): - results_for_rule: dict[str, list[Result]] +ResultType = TypeVar("ResultType", bound=Result) + + +class ResultSet(dict[str, dict[Path, list[ResultType]]]): + results_for_rule: dict[str, list[ResultType]] # stores SARIF runs.tool data tools: list[dict[str, dict]] @@ -235,7 +245,7 @@ def __init__(self, *args, **kwargs): self.results_for_rule = {} self.tools = [] - def add_result(self, result: Result): + def add_result(self, result: ResultType): self.results_for_rule.setdefault(result.rule_id, []).append(result) for loc in result.locations: self.setdefault(result.rule_id, {}).setdefault(loc.file, []).append(result) @@ -246,7 +256,7 @@ def store_tool_data(self, tool_data: dict): def results_for_rule_and_file( self, context: CodemodExecutionContext, rule_id: str, file: Path - ) -> list[Result]: + ) -> list[ResultType]: """ Return list of results for a given rule and file. @@ -258,7 +268,7 @@ def results_for_rule_and_file( """ return self.get(rule_id, {}).get(file.relative_to(context.directory), []) - def results_for_rules(self, rule_ids: list[str]) -> list[Result]: + def results_for_rules(self, rule_ids: list[str]) -> list[ResultType]: """ Returns flat list of all results that match any of the given rule IDs. """