Skip to content

Commit cf2d3fd

Browse files
committed
Use Sarif data model for detection
1 parent b91dc9e commit cf2d3fd

File tree

5 files changed

+24
-25
lines changed

5 files changed

+24
-25
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@ dependencies = [
2525
"tomlkit~=0.13.0",
2626
"wrapt~=1.17.0",
2727
"chardet~=5.2.0",
28+
"sarif-pydantic~=0.1.0",
2829
"setuptools~=78.1",
2930
]
3031
keywords = ["codemod", "codemods", "security", "fix", "fixes"]

src/codemodder/codeql.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -4,13 +4,13 @@
44
from typing_extensions import Self
55

66
from codemodder.result import LineInfo, ResultSet, SarifLocation, SarifResult
7-
from codemodder.sarifs import AbstractSarifToolDetector
7+
from codemodder.sarifs import AbstractSarifToolDetector, Run
88

99

1010
class CodeQLSarifToolDetector(AbstractSarifToolDetector):
1111
@classmethod
12-
def detect(cls, run_data: dict) -> bool:
13-
return "tool" in run_data and "CodeQL" in run_data["tool"]["driver"]["name"]
12+
def detect(cls, run_data: Run) -> bool:
13+
return "CodeQL" in run_data.tool.driver.name
1414

1515

1616
class CodeQLLocation(SarifLocation):

src/codemodder/sarifs.py

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1,17 +1,19 @@
1-
import json
21
from abc import ABCMeta, abstractmethod
32
from collections import defaultdict
43
from importlib.metadata import entry_points
54
from pathlib import Path
65
from typing import DefaultDict
76

7+
from pydantic import ValidationError
8+
from sarif_pydantic import Run, Sarif
9+
810
from codemodder.logging import logger
911

1012

1113
class AbstractSarifToolDetector(metaclass=ABCMeta):
1214
@classmethod
1315
@abstractmethod
14-
def detect(cls, run_data: dict) -> bool:
16+
def detect(cls, run_data: Run) -> bool:
1517
pass
1618

1719

@@ -24,18 +26,16 @@ def detect_sarif_tools(filenames: list[Path]) -> DefaultDict[str, list[Path]]:
2426
}
2527
for fname in filenames:
2628
try:
27-
data = json.loads(fname.read_text(encoding="utf-8-sig"))
28-
except json.JSONDecodeError:
29-
logger.exception("Malformed JSON file: %s", fname)
29+
data = Sarif.model_validate_json(fname.read_text(encoding="utf-8-sig"))
30+
except ValidationError:
31+
logger.exception("Invalid SARIF file: %s", fname)
3032
raise
31-
for name, det in detectors.items():
32-
try:
33-
runs = data["runs"]
34-
except KeyError:
35-
logger.exception("Sarif file without `runs` data: %s", fname)
36-
raise
3733

38-
for run in runs:
34+
if not data.runs:
35+
raise ValueError(f"SARIF file without `runs` data: {fname}")
36+
37+
for name, det in detectors.items():
38+
for run in data.runs:
3939
try:
4040
if det.detect(run):
4141
logger.debug("detected %s sarif: %s", name, fname)

src/codemodder/semgrep.py

Lines changed: 3 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -10,16 +10,13 @@
1010
from codemodder.context import CodemodExecutionContext
1111
from codemodder.logging import logger
1212
from codemodder.result import Result, ResultSet, SarifLocation, SarifResult
13-
from codemodder.sarifs import AbstractSarifToolDetector
13+
from codemodder.sarifs import AbstractSarifToolDetector, Run
1414

1515

1616
class SemgrepSarifToolDetector(AbstractSarifToolDetector):
1717
@classmethod
18-
def detect(cls, run_data: dict) -> bool:
19-
return (
20-
"tool" in run_data
21-
and "semgrep" in run_data["tool"]["driver"]["name"].lower()
22-
)
18+
def detect(cls, run_data: Run) -> bool:
19+
return "semgrep" in run_data.tool.driver.name
2320

2421

2522
class SemgrepLocation(SarifLocation):

tests/test_sarif_processing.py

Lines changed: 5 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
from pathlib import Path
44

55
import pytest
6+
from pydantic import ValidationError
67

78
from codemodder.codemods.semgrep import process_semgrep_findings
89
from codemodder.sarifs import detect_sarif_tools
@@ -119,9 +120,9 @@ def test_bad_sarif(self, tmpdir, caplog):
119120
# remove all { to make a badly formatted json
120121
f.write(sarif_file.read_text(encoding="utf-8").replace("{", ""))
121122

122-
with pytest.raises(json.JSONDecodeError):
123+
with pytest.raises(ValidationError):
123124
detect_sarif_tools([bad_json])
124-
assert f"Malformed JSON file: {str(bad_json)}" in caplog.text
125+
assert f"Invalid SARIF file: {str(bad_json)}" in caplog.text
125126

126127
def test_bad_sarif_no_runs_data(self, tmpdir, caplog):
127128
bad_json = tmpdir / "bad.sarif"
@@ -134,9 +135,9 @@ def test_bad_sarif_no_runs_data(self, tmpdir, caplog):
134135
with open(bad_json, "w") as f:
135136
f.write(data)
136137

137-
with pytest.raises(KeyError):
138+
with pytest.raises(ValidationError):
138139
detect_sarif_tools([bad_json])
139-
assert f"Sarif file without `runs` data: {str(bad_json)}" in caplog.text
140+
assert f"Invalid SARIF file: {str(bad_json)}" in caplog.text
140141

141142
def test_two_sarifs_different_tools(self):
142143
results = detect_sarif_tools(

0 commit comments

Comments
 (0)