Skip to content

Commit a1b2a54

Browse files
authored
Use sarif-pydantic for SARIF data models (#1034)
* Use Sarif data model for detection
* Intermediate progress: update to sarif-pydantic 0.2.0
* Fix some tests
* Bump sarif-pydantic dependency to 0.4.0
* Fix Sonar warnings
* Bump sarif-pydantic to 0.5.0
1 parent b91dc9e commit a1b2a54

22 files changed

+254
-180
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@ dependencies = [
2525
"tomlkit~=0.13.0",
2626
"wrapt~=1.17.0",
2727
"chardet~=5.2.0",
28+
"sarif-pydantic~=0.5.0",
2829
"setuptools~=78.1",
2930
]
3031
keywords = ["codemod", "codemods", "security", "fix", "fixes"]

src/codemodder/codeql.py

Lines changed: 29 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -1,40 +1,45 @@
1-
import json
21
from pathlib import Path
32

3+
from sarif_pydantic import Location as LocationModel
4+
from sarif_pydantic import Result as ResultModel
5+
from sarif_pydantic import Sarif
46
from typing_extensions import Self
57

68
from codemodder.result import LineInfo, ResultSet, SarifLocation, SarifResult
7-
from codemodder.sarifs import AbstractSarifToolDetector
9+
from codemodder.sarifs import AbstractSarifToolDetector, Run
810

911

1012
class CodeQLSarifToolDetector(AbstractSarifToolDetector):
1113
@classmethod
12-
def detect(cls, run_data: dict) -> bool:
13-
return "tool" in run_data and "CodeQL" in run_data["tool"]["driver"]["name"]
14+
def detect(cls, run_data: Run) -> bool:
15+
return "CodeQL" in run_data.tool.driver.name
1416

1517

1618
class CodeQLLocation(SarifLocation):
17-
@staticmethod
18-
def get_snippet(sarif_location) -> str:
19-
return ""
20-
2119
@classmethod
22-
def from_sarif(cls, sarif_location) -> Self:
23-
artifact_location = sarif_location["physicalLocation"]["artifactLocation"]
24-
file = Path(artifact_location["uri"])
20+
def from_sarif(cls, sarif_location: LocationModel) -> Self:
21+
if (physical_location := sarif_location.physical_location) is None:
22+
raise ValueError("Location does not contain a physicalLocation")
23+
if (artifact_location := physical_location.artifact_location) is None:
24+
raise ValueError("PhysicalLocation does not contain an artifactLocation")
25+
if (uri := artifact_location.uri) is None:
26+
raise ValueError("ArtifactLocation does not contain a uri")
27+
28+
file = Path(uri)
2529

26-
try:
27-
region = sarif_location["physicalLocation"]["region"]
28-
except KeyError:
30+
if not (region := physical_location.region):
2931
# A location without a region indicates a result for the entire file.
3032
# Use sentinel values of 0 index for start/end
3133
zero = LineInfo(0)
3234
return cls(file=file, start=zero, end=zero)
3335

34-
start = LineInfo(line=region["startLine"], column=region.get("startColumn"))
36+
if not region.start_line:
37+
raise ValueError("Region does not contain a startLine")
38+
39+
start = LineInfo(line=region.start_line, column=region.start_column or -1)
3540
end = LineInfo(
36-
line=region.get("endLine", start.line),
37-
column=region.get("endColumn", start.column),
41+
line=region.end_line or start.line,
42+
column=region.end_column or start.column,
3843
)
3944
return cls(file=file, start=start, end=end)
4045

@@ -43,7 +48,7 @@ class CodeQLResult(SarifResult):
4348
location_type = CodeQLLocation
4449

4550
@classmethod
46-
def rule_url_from_id(cls, result: dict, run: dict, rule_id: str) -> str:
51+
def rule_url_from_id(cls, result: ResultModel, run: Run, rule_id: str) -> str:
4752
del result, run, rule_id
4853
# TODO: Implement this method to return the specific rule URL
4954
return "https://codeql.github.com/codeql-query-help/"
@@ -52,16 +57,17 @@ def rule_url_from_id(cls, result: dict, run: dict, rule_id: str) -> str:
5257
class CodeQLResultSet(ResultSet):
5358
@classmethod
5459
def from_sarif(cls, sarif_file: str | Path, truncate_rule_id: bool = False) -> Self:
55-
with open(sarif_file, "r", encoding="utf-8") as f:
56-
data = json.load(f)
60+
data = Sarif.model_validate_json(
61+
Path(sarif_file).read_text(encoding="utf-8-sig")
62+
)
5763

5864
result_set = cls()
59-
for sarif_run in data["runs"]:
65+
for sarif_run in data.runs:
6066
if CodeQLSarifToolDetector.detect(sarif_run):
61-
for sarif_result in sarif_run["results"]:
67+
for sarif_result in sarif_run.results or []:
6268
codeql_result = CodeQLResult.from_sarif(
6369
sarif_result, sarif_run, truncate_rule_id
6470
)
6571
result_set.add_result(codeql_result)
66-
result_set.store_tool_data(sarif_run.get("tool", {}))
72+
result_set.store_tool_data(sarif_run.tool.model_dump())
6773
return result_set

src/codemodder/result.py

Lines changed: 60 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -1,14 +1,16 @@
11
from __future__ import annotations
22

33
import itertools
4-
from abc import abstractmethod
54
from dataclasses import dataclass, field
65
from pathlib import Path
76
from typing import TYPE_CHECKING, Any, ClassVar, Sequence, Type
87

98
import libcst as cst
109
from boltons.setutils import IndexedSet
1110
from libcst._position import CodeRange
11+
from sarif_pydantic import Location as LocationModel
12+
from sarif_pydantic import Result as ResultModel
13+
from sarif_pydantic import Run
1214
from typing_extensions import Self
1315

1416
from codemodder.codetf import Finding, Rule
@@ -36,23 +38,30 @@ class Location(ABCDataclass):
3638
@dataclass(frozen=True)
3739
class SarifLocation(Location):
3840
@staticmethod
39-
@abstractmethod
40-
def get_snippet(sarif_location) -> str:
41-
pass
41+
def get_snippet(sarif_location: LocationModel) -> str | None:
42+
return sarif_location.message.text if sarif_location.message else None
4243

4344
@classmethod
44-
def from_sarif(cls, sarif_location) -> Self:
45-
artifact_location = sarif_location["physicalLocation"]["artifactLocation"]
46-
file = Path(artifact_location["uri"])
45+
def from_sarif(cls, sarif_location: LocationModel) -> Self:
46+
if not (physical_location := sarif_location.physical_location):
47+
raise ValueError("Sarif location does not have a physical location")
48+
if not (artifact_location := physical_location.artifact_location):
49+
raise ValueError("Sarif location does not have an artifact location")
50+
if not (region := physical_location.region):
51+
raise ValueError("Sarif location does not have a region")
52+
if not (uri := artifact_location.uri):
53+
raise ValueError("Sarif location does not have a uri")
54+
55+
file = Path(uri)
4756
snippet = cls.get_snippet(sarif_location)
4857
start = LineInfo(
49-
line=sarif_location["physicalLocation"]["region"]["startLine"],
50-
column=sarif_location["physicalLocation"]["region"]["startColumn"],
58+
line=region.start_line or -1,
59+
column=region.start_column or -1,
5160
snippet=snippet,
5261
)
5362
end = LineInfo(
54-
line=sarif_location["physicalLocation"]["region"]["endLine"],
55-
column=sarif_location["physicalLocation"]["region"]["endColumn"],
63+
line=region.end_line or -1,
64+
column=region.end_column or -1,
5665
snippet=snippet,
5766
)
5867
return cls(file=file, start=start, end=end)
@@ -102,7 +111,7 @@ class SarifResult(SASTResult):
102111

103112
@classmethod
104113
def from_sarif(
105-
cls, sarif_result, sarif_run, truncate_rule_id: bool = False
114+
cls, sarif_result: ResultModel, sarif_run: Run, truncate_rule_id: bool = False
106115
) -> Self:
107116
rule_id = cls.extract_rule_id(sarif_result, sarif_run, truncate_rule_id)
108117
finding_id = cls.extract_finding_id(sarif_result) or rule_id
@@ -124,68 +133,84 @@ def from_sarif(
124133
)
125134

126135
@classmethod
127-
def extract_finding_message(cls, sarif_result: dict, sarif_run: dict) -> str | None:
128-
return sarif_result.get("message", {}).get("text", None)
136+
def extract_finding_message(
137+
cls, sarif_result: ResultModel, sarif_run: Run
138+
) -> str | None:
139+
del sarif_run
140+
return sarif_result.message.text
129141

130142
@classmethod
131-
def rule_url_from_id(cls, result: dict, run: dict, rule_id: str) -> str | None:
143+
def rule_url_from_id(
144+
cls, result: ResultModel, run: Run, rule_id: str
145+
) -> str | None:
132146
del result, run, rule_id
133147
return None
134148

135149
@classmethod
136-
def extract_locations(cls, sarif_result) -> Sequence[Location]:
150+
def extract_locations(cls, sarif_result: ResultModel) -> Sequence[Location]:
137151
return tuple(
138152
[
139153
cls.location_type.from_sarif(location)
140-
for location in sarif_result["locations"]
154+
for location in sarif_result.locations or []
141155
]
142156
)
143157

144158
@classmethod
145-
def extract_related_locations(cls, sarif_result) -> Sequence[LocationWithMessage]:
159+
def extract_related_locations(
160+
cls, sarif_result: ResultModel
161+
) -> Sequence[LocationWithMessage]:
146162
return tuple(
147163
[
148164
LocationWithMessage(
149-
message=rel_location.get("message", {}).get("text", ""),
165+
message=rel_location.message.text,
150166
location=cls.location_type.from_sarif(rel_location),
151167
)
152-
for rel_location in sarif_result.get("relatedLocations", [])
168+
for rel_location in sarif_result.related_locations or []
169+
if rel_location.message
153170
]
154171
)
155172

156173
@classmethod
157-
def extract_code_flows(cls, sarif_result) -> Sequence[Sequence[Location]]:
174+
def extract_code_flows(
175+
cls, sarif_result: ResultModel
176+
) -> Sequence[Sequence[Location]]:
158177
return tuple(
159178
[
160179
tuple(
161180
[
162-
cls.location_type.from_sarif(locations.get("location"))
163-
for locations in threadflow.get("locations", {})
181+
cls.location_type.from_sarif(locations.location)
182+
for locations in threadflow.locations or []
183+
if locations.location
164184
]
165185
)
166-
for codeflow in sarif_result.get("codeFlows", {})
167-
for threadflow in codeflow.get("threadFlows", {})
186+
for codeflow in sarif_result.code_flows or []
187+
for threadflow in codeflow.thread_flows or []
168188
]
169189
)
170190

171191
@classmethod
172-
def extract_rule_id(cls, result, sarif_run, truncate_rule_id: bool = False) -> str:
173-
if rule_id := result.get("ruleId"):
192+
def extract_rule_id(
193+
cls, result: ResultModel, sarif_run: Run, truncate_rule_id: bool = False
194+
) -> str:
195+
if rule_id := result.rule_id:
174196
return rule_id.split(".")[-1] if truncate_rule_id else rule_id
175197

176198
# it may be contained in the 'rule' field through the tool component in the sarif file
177-
if "rule" in result:
178-
tool_index = result["rule"]["toolComponent"]["index"]
179-
rule_index = result["rule"]["index"]
180-
return sarif_run["tool"]["extensions"][tool_index]["rules"][rule_index][
181-
"id"
182-
]
199+
if (
200+
(rule := result.rule)
201+
and sarif_run.tool.extensions
202+
and rule.tool_component
203+
and rule.tool_component.index is not None
204+
):
205+
tool_index = rule.tool_component.index
206+
rule_index = rule.index
207+
return sarif_run.tool.extensions[tool_index].rules[rule_index].id
183208

184209
raise ValueError("Could not extract rule id from sarif result.")
185210

186211
@classmethod
187-
def extract_finding_id(cls, result) -> str | None:
188-
return result.get("guid") or result.get("correlationGuid")
212+
def extract_finding_id(cls, result: ResultModel) -> str | None:
213+
return str(result.guid or "") or str(result.correlation_guid or "") or None
189214

190215

191216
def same_line(pos: CodeRange, location: Location) -> bool:

src/codemodder/sarifs.py

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1,17 +1,19 @@
1-
import json
21
from abc import ABCMeta, abstractmethod
32
from collections import defaultdict
43
from importlib.metadata import entry_points
54
from pathlib import Path
65
from typing import DefaultDict
76

7+
from pydantic import ValidationError
8+
from sarif_pydantic import Run, Sarif
9+
810
from codemodder.logging import logger
911

1012

1113
class AbstractSarifToolDetector(metaclass=ABCMeta):
1214
@classmethod
1315
@abstractmethod
14-
def detect(cls, run_data: dict) -> bool:
16+
def detect(cls, run_data: Run) -> bool:
1517
pass
1618

1719

@@ -24,18 +26,16 @@ def detect_sarif_tools(filenames: list[Path]) -> DefaultDict[str, list[Path]]:
2426
}
2527
for fname in filenames:
2628
try:
27-
data = json.loads(fname.read_text(encoding="utf-8-sig"))
28-
except json.JSONDecodeError:
29-
logger.exception("Malformed JSON file: %s", fname)
29+
data = Sarif.model_validate_json(fname.read_text(encoding="utf-8-sig"))
30+
except ValidationError:
31+
logger.exception("Invalid SARIF file: %s", fname)
3032
raise
31-
for name, det in detectors.items():
32-
try:
33-
runs = data["runs"]
34-
except KeyError:
35-
logger.exception("Sarif file without `runs` data: %s", fname)
36-
raise
3733

38-
for run in runs:
34+
if not data.runs:
35+
raise ValueError(f"SARIF file without `runs` data: {fname}")
36+
37+
for name, det in detectors.items():
38+
for run in data.runs:
3939
try:
4040
if det.detect(run):
4141
logger.debug("detected %s sarif: %s", name, fname)

0 commit comments

Comments
 (0)