Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ dependencies = [
"tomlkit~=0.13.0",
"wrapt~=1.17.0",
"chardet~=5.2.0",
"sarif-pydantic~=0.5.0",
"setuptools~=78.1",
]
keywords = ["codemod", "codemods", "security", "fix", "fixes"]
Expand Down
52 changes: 29 additions & 23 deletions src/codemodder/codeql.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,45 @@
import json
from pathlib import Path

from sarif_pydantic import Location as LocationModel
from sarif_pydantic import Result as ResultModel
from sarif_pydantic import Sarif
from typing_extensions import Self

from codemodder.result import LineInfo, ResultSet, SarifLocation, SarifResult
from codemodder.sarifs import AbstractSarifToolDetector
from codemodder.sarifs import AbstractSarifToolDetector, Run


class CodeQLSarifToolDetector(AbstractSarifToolDetector):
    """Detector that recognizes SARIF runs produced by CodeQL."""

    @classmethod
    def detect(cls, run_data: Run) -> bool:
        """Return True when the run's tool driver name contains ``"CodeQL"``.

        Args:
            run_data: A parsed SARIF run; only ``tool.driver.name`` is read.
        """
        return "CodeQL" in run_data.tool.driver.name


class CodeQLLocation(SarifLocation):
    """Source location of a CodeQL finding, built from a parsed SARIF location."""

    @classmethod
    def from_sarif(cls, sarif_location: LocationModel) -> Self:
        """Build a location from a SARIF location model.

        A location without a region is treated as a whole-file result and gets
        ``LineInfo(0)`` for both start and end. Otherwise missing end values
        default to the corresponding start values, and a missing start column
        becomes ``-1``.

        Raises:
            ValueError: if the physical location, artifact location or uri is
                ``None``, or if a region is present without a start line.
        """
        if (physical_location := sarif_location.physical_location) is None:
            raise ValueError("Location does not contain a physicalLocation")
        if (artifact_location := physical_location.artifact_location) is None:
            raise ValueError("PhysicalLocation does not contain an artifactLocation")
        if (uri := artifact_location.uri) is None:
            raise ValueError("ArtifactLocation does not contain a uri")

        file = Path(uri)

        if not (region := physical_location.region):
            # A location without a region indicates a result for the entire file.
            # Use sentinel values of 0 index for start/end
            zero = LineInfo(0)
            return cls(file=file, start=zero, end=zero)

        if not region.start_line:
            raise ValueError("Region does not contain a startLine")

        start = LineInfo(line=region.start_line, column=region.start_column or -1)
        end = LineInfo(
            line=region.end_line or start.line,
            column=region.end_column or start.column,
        )
        return cls(file=file, start=start, end=end)

Expand All @@ -43,7 +48,7 @@ class CodeQLResult(SarifResult):
location_type = CodeQLLocation

@classmethod
def rule_url_from_id(cls, result: ResultModel, run: Run, rule_id: str) -> str:
    """Return the generic CodeQL query-help URL; all arguments are ignored."""
    del result, run, rule_id
    # TODO: Implement this method to return the specific rule URL
    return "https://codeql.github.com/codeql-query-help/"
Expand All @@ -52,16 +57,17 @@ def rule_url_from_id(cls, result: dict, run: dict, rule_id: str) -> str:
class CodeQLResultSet(ResultSet):
    """Result set populated from the CodeQL runs of a SARIF file."""

    @classmethod
    def from_sarif(cls, sarif_file: str | Path, truncate_rule_id: bool = False) -> Self:
        """Parse ``sarif_file`` and collect the results of every CodeQL run.

        Args:
            sarif_file: Path to the SARIF file; read as ``utf-8-sig`` so a
                leading byte-order mark is tolerated.
            truncate_rule_id: Forwarded to ``CodeQLResult.from_sarif``.

        Returns:
            A new result set holding one result per SARIF result found in runs
            accepted by ``CodeQLSarifToolDetector``.
        """
        data = Sarif.model_validate_json(
            Path(sarif_file).read_text(encoding="utf-8-sig")
        )

        result_set = cls()
        for sarif_run in data.runs:
            if CodeQLSarifToolDetector.detect(sarif_run):
                # A run may carry no results at all.
                for sarif_result in sarif_run.results or []:
                    codeql_result = CodeQLResult.from_sarif(
                        sarif_result, sarif_run, truncate_rule_id
                    )
                    result_set.add_result(codeql_result)
                # NOTE(review): nesting reconstructed from a view without
                # indentation -- confirm this belongs inside the detector check.
                # NOTE(review): model_dump() presumably emits the model's field
                # names rather than the SARIF camelCase keys -- confirm that
                # consumers of the stored tool data expect that shape.
                result_set.store_tool_data(sarif_run.tool.model_dump())
        return result_set
95 changes: 60 additions & 35 deletions src/codemodder/result.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from __future__ import annotations

import itertools
from abc import abstractmethod
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, Sequence, Type

import libcst as cst
from boltons.setutils import IndexedSet
from libcst._position import CodeRange
from sarif_pydantic import Location as LocationModel
from sarif_pydantic import Result as ResultModel
from sarif_pydantic import Run
from typing_extensions import Self

from codemodder.codetf import Finding, Rule
Expand Down Expand Up @@ -36,23 +38,30 @@ class Location(ABCDataclass):
@dataclass(frozen=True)
class SarifLocation(Location):
@staticmethod
def get_snippet(sarif_location: LocationModel) -> str | None:
    """Return the location's message text, or ``None`` if it has no message."""
    return sarif_location.message.text if sarif_location.message else None

@classmethod
def from_sarif(cls, sarif_location: LocationModel) -> Self:
    """Build a location from a parsed SARIF location.

    The snippet returned by ``cls.get_snippet`` is attached to both the start
    and the end ``LineInfo``. Any line or column value that is missing (or
    falsy) in the region is replaced by ``-1``.

    Raises:
        ValueError: if the physical location, artifact location, region or
            uri is missing.
    """
    if not (physical_location := sarif_location.physical_location):
        raise ValueError("Sarif location does not have a physical location")
    if not (artifact_location := physical_location.artifact_location):
        raise ValueError("Sarif location does not have an artifact location")
    if not (region := physical_location.region):
        raise ValueError("Sarif location does not have a region")
    if not (uri := artifact_location.uri):
        raise ValueError("Sarif location does not have a uri")

    file = Path(uri)
    snippet = cls.get_snippet(sarif_location)
    start = LineInfo(
        line=region.start_line or -1,
        column=region.start_column or -1,
        snippet=snippet,
    )
    end = LineInfo(
        line=region.end_line or -1,
        column=region.end_column or -1,
        snippet=snippet,
    )
    return cls(file=file, start=start, end=end)
Expand Down Expand Up @@ -102,7 +111,7 @@ class SarifResult(SASTResult):

@classmethod
def from_sarif(
cls, sarif_result, sarif_run, truncate_rule_id: bool = False
cls, sarif_result: ResultModel, sarif_run: Run, truncate_rule_id: bool = False
) -> Self:
rule_id = cls.extract_rule_id(sarif_result, sarif_run, truncate_rule_id)
finding_id = cls.extract_finding_id(sarif_result) or rule_id
Expand All @@ -124,68 +133,84 @@ def from_sarif(
)

@classmethod
def extract_finding_message(
    cls, sarif_result: ResultModel, sarif_run: Run
) -> str | None:
    """Return the text of the result's message; ``sarif_run`` is unused."""
    del sarif_run
    return sarif_result.message.text

@classmethod
def rule_url_from_id(
    cls, result: ResultModel, run: Run, rule_id: str
) -> str | None:
    """Return ``None``: this base implementation knows no rule URL.

    All arguments are ignored here.
    """
    del result, run, rule_id
    return None

@classmethod
def extract_locations(cls, sarif_result: ResultModel) -> Sequence[Location]:
    """Convert the result's locations with ``cls.location_type.from_sarif``.

    Returns:
        A tuple with one converted location per entry of
        ``sarif_result.locations``; empty when that attribute is ``None`` or
        empty.
    """
    return tuple(
        cls.location_type.from_sarif(location)
        for location in sarif_result.locations or []
    )

@classmethod
def extract_related_locations(
    cls, sarif_result: ResultModel
) -> Sequence[LocationWithMessage]:
    """Convert the result's related locations that carry a message.

    Related locations without a message are skipped. Each remaining entry is
    converted with ``cls.location_type.from_sarif`` and paired with its
    message text.

    Returns:
        A tuple of ``LocationWithMessage``; empty when the result has no
        related locations.
    """
    return tuple(
        LocationWithMessage(
            # A message may lack text; fall back to "" as the dict-based
            # implementation did, so ``message`` is always a string.
            message=rel_location.message.text or "",
            location=cls.location_type.from_sarif(rel_location),
        )
        for rel_location in sarif_result.related_locations or []
        if rel_location.message
    )

@classmethod
def extract_code_flows(
    cls, sarif_result: ResultModel
) -> Sequence[Sequence[Location]]:
    """Convert every thread flow of the result into a tuple of locations.

    Thread flows from all code flows are flattened into one outer tuple, in
    order. Within a thread flow, entries without a location are skipped and
    the rest are converted with ``cls.location_type.from_sarif``.

    Returns:
        One inner tuple per thread flow (possibly empty); the outer tuple is
        empty when the result has no code flows.
    """
    return tuple(
        tuple(
            cls.location_type.from_sarif(step.location)
            for step in threadflow.locations or []
            if step.location
        )
        for codeflow in sarif_result.code_flows or []
        for threadflow in codeflow.thread_flows or []
    )

@classmethod
def extract_rule_id(
    cls, result: ResultModel, sarif_run: Run, truncate_rule_id: bool = False
) -> str:
    """Return the rule id of a SARIF result.

    The result's own ``rule_id`` wins when present; with ``truncate_rule_id``
    only the part after the last ``"."`` is returned. Otherwise the id is
    looked up through the result's ``rule`` reference in the run's tool
    extensions (this path does not truncate).

    Raises:
        ValueError: if neither source yields a rule id, including when the
            rule reference lacks an index or the referenced extension has no
            rules.
    """
    if rule_id := result.rule_id:
        return rule_id.split(".")[-1] if truncate_rule_id else rule_id

    # it may be contained in the 'rule' field through the tool component in the sarif file
    if (
        (rule := result.rule)
        and (extensions := sarif_run.tool.extensions)
        and rule.tool_component
        and rule.tool_component.index is not None
        # Both of these can be absent; without the checks the lookup below
        # would fail with a TypeError instead of the ValueError raised here.
        and rule.index is not None
        and (rules := extensions[rule.tool_component.index].rules)
    ):
        return rules[rule.index].id

    raise ValueError("Could not extract rule id from sarif result.")

@classmethod
def extract_finding_id(cls, result: ResultModel) -> str | None:
    """Return the result's ``guid`` as a string, else its ``correlation_guid``.

    Returns:
        The first of the two identifiers that is set, converted with
        ``str``; ``None`` when neither is set.
    """
    for candidate in (result.guid, result.correlation_guid):
        if candidate:
            return str(candidate)
    return None


def same_line(pos: CodeRange, location: Location) -> bool:
Expand Down
24 changes: 12 additions & 12 deletions src/codemodder/sarifs.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
import json
from abc import ABCMeta, abstractmethod
from collections import defaultdict
from importlib.metadata import entry_points
from pathlib import Path
from typing import DefaultDict

from pydantic import ValidationError
from sarif_pydantic import Run, Sarif

from codemodder.logging import logger


class AbstractSarifToolDetector(metaclass=ABCMeta):
    """Interface for detectors that recognize a tool from a SARIF run."""

    @classmethod
    @abstractmethod
    def detect(cls, run_data: Run) -> bool:
        """Return whether ``run_data`` comes from the tool this detector targets.

        Args:
            run_data: A single parsed SARIF run.
        """


Expand All @@ -24,18 +26,16 @@ def detect_sarif_tools(filenames: list[Path]) -> DefaultDict[str, list[Path]]:
}
for fname in filenames:
try:
data = json.loads(fname.read_text(encoding="utf-8-sig"))
except json.JSONDecodeError:
logger.exception("Malformed JSON file: %s", fname)
data = Sarif.model_validate_json(fname.read_text(encoding="utf-8-sig"))
except ValidationError:
logger.exception("Invalid SARIF file: %s", fname)
raise
for name, det in detectors.items():
try:
runs = data["runs"]
except KeyError:
logger.exception("Sarif file without `runs` data: %s", fname)
raise

for run in runs:
if not data.runs:
raise ValueError(f"SARIF file without `runs` data: {fname}")

for name, det in detectors.items():
for run in data.runs:
try:
if det.detect(run):
logger.debug("detected %s sarif: %s", name, fname)
Expand Down
Loading
Loading