Skip to content

Commit e2a20ce

Browse files
committed
Intermediate progress: update to sarif-pydantic 0.2.0
1 parent cf2d3fd commit e2a20ce

File tree

3 files changed: 41 additions (+41) and 29 deletions (-29)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -25,7 +25,7 @@ dependencies = [
2525
"tomlkit~=0.13.0",
2626
"wrapt~=1.17.0",
2727
"chardet~=5.2.0",
28-
"sarif-pydantic~=0.1.0",
28+
"sarif-pydantic~=0.2.0",
2929
"setuptools~=78.1",
3030
]
3131
keywords = ["codemod", "codemods", "security", "fix", "fixes"]

src/codemodder/codeql.py

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
1-
import json
21
from pathlib import Path
32

3+
from sarif_pydantic import Sarif
44
from typing_extensions import Self
55

66
from codemodder.result import LineInfo, ResultSet, SarifLocation, SarifResult
@@ -52,16 +52,17 @@ def rule_url_from_id(cls, result: dict, run: dict, rule_id: str) -> str:
5252
class CodeQLResultSet(ResultSet):
5353
@classmethod
5454
def from_sarif(cls, sarif_file: str | Path, truncate_rule_id: bool = False) -> Self:
55-
with open(sarif_file, "r", encoding="utf-8") as f:
56-
data = json.load(f)
55+
data = Sarif.model_validate_json(
56+
Path(sarif_file).read_text(encoding="utf-8-sig")
57+
)
5758

5859
result_set = cls()
59-
for sarif_run in data["runs"]:
60+
for sarif_run in data.runs:
6061
if CodeQLSarifToolDetector.detect(sarif_run):
61-
for sarif_result in sarif_run["results"]:
62+
for sarif_result in sarif_run.results or []:
6263
codeql_result = CodeQLResult.from_sarif(
6364
sarif_result, sarif_run, truncate_rule_id
6465
)
6566
result_set.add_result(codeql_result)
66-
result_set.store_tool_data(sarif_run.get("tool", {}))
67+
result_set.store_tool_data(sarif_run.tool.model_dump())
6768
return result_set

src/codemodder/result.py

Lines changed: 33 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,9 @@
99
import libcst as cst
1010
from boltons.setutils import IndexedSet
1111
from libcst._position import CodeRange
12+
from sarif_pydantic import Location as LocationModel
13+
from sarif_pydantic import Result as ResultModel
14+
from sarif_pydantic import Run
1215
from typing_extensions import Self
1316

1417
from codemodder.codetf import Finding, Rule
@@ -41,7 +44,7 @@ def get_snippet(sarif_location) -> str:
4144
pass
4245

4346
@classmethod
44-
def from_sarif(cls, sarif_location) -> Self:
47+
def from_sarif(cls, sarif_location: LocationModel) -> Self:
4548
artifact_location = sarif_location["physicalLocation"]["artifactLocation"]
4649
file = Path(artifact_location["uri"])
4750
snippet = cls.get_snippet(sarif_location)
@@ -102,7 +105,7 @@ class SarifResult(SASTResult):
102105

103106
@classmethod
104107
def from_sarif(
105-
cls, sarif_result, sarif_run, truncate_rule_id: bool = False
108+
cls, sarif_result: ResultModel, sarif_run: Run, truncate_rule_id: bool = False
106109
) -> Self:
107110
rule_id = cls.extract_rule_id(sarif_result, sarif_run, truncate_rule_id)
108111
finding_id = cls.extract_finding_id(sarif_result) or rule_id
@@ -124,20 +127,25 @@ def from_sarif(
124127
)
125128

126129
@classmethod
127-
def extract_finding_message(cls, sarif_result: dict, sarif_run: dict) -> str | None:
128-
return sarif_result.get("message", {}).get("text", None)
130+
def extract_finding_message(
131+
cls, sarif_result: ResultModel, sarif_run: Run
132+
) -> str | None:
133+
del sarif_run
134+
return sarif_result.message.text
129135

130136
@classmethod
131-
def rule_url_from_id(cls, result: dict, run: dict, rule_id: str) -> str | None:
137+
def rule_url_from_id(
138+
cls, result: ResultModel, run: Run, rule_id: str
139+
) -> str | None:
132140
del result, run, rule_id
133141
return None
134142

135143
@classmethod
136-
def extract_locations(cls, sarif_result) -> Sequence[Location]:
144+
def extract_locations(cls, sarif_result: ResultModel) -> Sequence[Location]:
137145
return tuple(
138146
[
139147
cls.location_type.from_sarif(location)
140-
for location in sarif_result["locations"]
148+
for location in sarif_result.locations or []
141149
]
142150
)
143151

@@ -154,38 +162,41 @@ def extract_related_locations(cls, sarif_result) -> Sequence[LocationWithMessage
154162
)
155163

156164
@classmethod
157-
def extract_code_flows(cls, sarif_result) -> Sequence[Sequence[Location]]:
165+
def extract_code_flows(
166+
cls, sarif_result: ResultModel
167+
) -> Sequence[Sequence[Location]]:
158168
return tuple(
159169
[
160170
tuple(
161171
[
162-
cls.location_type.from_sarif(locations.get("location"))
163-
for locations in threadflow.get("locations", {})
172+
cls.location_type.from_sarif(locations.location)
173+
for locations in threadflow.locations or []
174+
if locations.location
164175
]
165176
)
166-
for codeflow in sarif_result.get("codeFlows", {})
167-
for threadflow in codeflow.get("threadFlows", {})
177+
for codeflow in sarif_result.code_flows or []
178+
for threadflow in codeflow.thread_flows or []
168179
]
169180
)
170181

171182
@classmethod
172-
def extract_rule_id(cls, result, sarif_run, truncate_rule_id: bool = False) -> str:
173-
if rule_id := result.get("ruleId"):
183+
def extract_rule_id(
184+
cls, result: ResultModel, sarif_run: Run, truncate_rule_id: bool = False
185+
) -> str:
186+
if rule_id := result.rule_id:
174187
return rule_id.split(".")[-1] if truncate_rule_id else rule_id
175188

176189
# it may be contained in the 'rule' field through the tool component in the sarif file
177-
if "rule" in result:
178-
tool_index = result["rule"]["toolComponent"]["index"]
179-
rule_index = result["rule"]["index"]
180-
return sarif_run["tool"]["extensions"][tool_index]["rules"][rule_index][
181-
"id"
182-
]
190+
if (rule := result.rule) and sarif_run.tool.extensions and rule.tool_component:
191+
tool_index = rule.tool_component.index
192+
rule_index = rule.index
193+
return sarif_run.tool.extensions[tool_index].rules[rule_index].id
183194

184195
raise ValueError("Could not extract rule id from sarif result.")
185196

186197
@classmethod
187-
def extract_finding_id(cls, result) -> str | None:
188-
return result.get("guid") or result.get("correlationGuid")
198+
def extract_finding_id(cls, result: ResultModel) -> str | None:
199+
return str(result.guid or "") or str(result.correlation_guid or "") or None
189200

190201

191202
def same_line(pos: CodeRange, location: Location) -> bool:

0 commit comments

Comments
 (0)