Skip to content

Commit 8cb84ac

Browse files
authored
Refactor Sarif Result (#583)
* create base class sarif result
* more refactor
* create sarif loc
1 parent ef68451 commit 8cb84ac

File tree

3 files changed

+78
-101
lines changed

3 files changed

+78
-101
lines changed

src/codemodder/codeql.py

Lines changed: 5 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
from typing_extensions import Self
55

6-
from codemodder.result import LineInfo, Location, LocationWithMessage, Result, ResultSet
6+
from codemodder.result import LineInfo, ResultSet, SarifLocation, SarifResult
77
from codemodder.sarifs import AbstractSarifToolDetector
88

99

@@ -13,7 +13,7 @@ def detect(cls, run_data: dict) -> bool:
1313
return "tool" in run_data and "CodeQL" in run_data["tool"]["driver"]["name"]
1414

1515

16-
class CodeQLLocation(Location):
16+
class CodeQLLocation(SarifLocation):
1717
@classmethod
1818
def from_sarif(cls, sarif_location) -> Self:
1919
artifact_location = sarif_location["physicalLocation"]["artifactLocation"]
@@ -35,46 +35,8 @@ def from_sarif(cls, sarif_location) -> Self:
3535
return cls(file=file, start=start, end=end)
3636

3737

38-
class CodeQLResult(Result):
39-
@classmethod
40-
def from_sarif(
41-
cls, sarif_result, sarif_run, rule_extensions, truncate_rule_id: bool = False
42-
) -> Self:
43-
extension_index = sarif_result["rule"]["toolComponent"]["index"]
44-
tool_index = sarif_result["rule"]["index"]
45-
rule_data = rule_extensions[extension_index]["rules"][tool_index]
46-
47-
locations: list[Location] = []
48-
for location in sarif_result["locations"]:
49-
try:
50-
codeql_location = CodeQLLocation.from_sarif(location)
51-
except KeyError:
52-
continue
53-
54-
locations.append(codeql_location)
55-
all_flows: list[list[Location]] = [
56-
[
57-
CodeQLLocation.from_sarif(locations.get("location"))
58-
for locations in threadflow.get("locations", {})
59-
]
60-
for codeflow in sarif_result.get("codeFlows", {})
61-
for threadflow in codeflow.get("threadFlows", {})
62-
]
63-
related_locations: list[LocationWithMessage] = []
64-
if "relatedLocations" in sarif_result:
65-
related_locations = [
66-
LocationWithMessage(
67-
message=rel_location.get("message", {}).get("text", ""),
68-
location=CodeQLLocation.from_sarif(rel_location),
69-
)
70-
for rel_location in sarif_result.get("relatedLocations", [])
71-
]
72-
return cls(
73-
rule_id=rule_data["id"],
74-
locations=locations,
75-
codeflows=all_flows,
76-
related_locations=related_locations,
77-
)
38+
class CodeQLResult(SarifResult):
39+
location_type = CodeQLLocation
7840

7941

8042
class CodeQLResultSet(ResultSet):
@@ -85,11 +47,10 @@ def from_sarif(cls, sarif_file: str | Path, truncate_rule_id: bool = False) -> S
8547

8648
result_set = cls()
8749
for sarif_run in data["runs"]:
88-
rule_extensions = sarif_run["tool"]["extensions"]
8950
if CodeQLSarifToolDetector.detect(sarif_run):
9051
for sarif_result in sarif_run["results"]:
9152
codeql_result = CodeQLResult.from_sarif(
92-
sarif_result, sarif_run, rule_extensions, truncate_rule_id
53+
sarif_result, sarif_run, truncate_rule_id
9354
)
9455
result_set.add_result(codeql_result)
9556
return result_set

src/codemodder/result.py

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
from __future__ import annotations
22

3+
from abc import abstractmethod
34
from dataclasses import dataclass, field
45
from pathlib import Path
5-
from typing import TYPE_CHECKING, Any
6+
from typing import TYPE_CHECKING, Any, ClassVar, Type
67

78
import libcst as cst
89
from libcst._position import CodeRange
10+
from typing_extensions import Self
911

1012
from codemodder.codetf import Finding
1113

@@ -29,6 +31,13 @@ class Location(ABCDataclass):
2931
end: LineInfo
3032

3133

34+
class SarifLocation(Location):
35+
@classmethod
36+
@abstractmethod
37+
def from_sarif(cls, sarif_location) -> Self:
38+
pass
39+
40+
3241
@dataclass
3342
class LocationWithMessage:
3443
location: Location
@@ -58,6 +67,65 @@ def match_location(self, pos: CodeRange, node: cst.CSTNode) -> bool:
5867
)
5968

6069

70+
@dataclass(kw_only=True)
71+
class SarifResult(Result, ABCDataclass):
72+
location_type: ClassVar[Type[SarifLocation]]
73+
74+
@classmethod
75+
def from_sarif(
76+
cls, sarif_result, sarif_run, truncate_rule_id: bool = False
77+
) -> Self:
78+
return cls(
79+
rule_id=cls.extract_rule_id(sarif_result, sarif_run, truncate_rule_id),
80+
locations=cls.extract_locations(sarif_result),
81+
codeflows=cls.extract_code_flows(sarif_result),
82+
related_locations=cls.extract_related_locations(sarif_result),
83+
)
84+
85+
@classmethod
86+
def extract_locations(cls, sarif_result) -> list[Location]:
87+
return [
88+
cls.location_type.from_sarif(location)
89+
for location in sarif_result["locations"]
90+
]
91+
92+
@classmethod
93+
def extract_related_locations(cls, sarif_result) -> list[LocationWithMessage]:
94+
return [
95+
LocationWithMessage(
96+
message=rel_location.get("message", {}).get("text", ""),
97+
location=cls.location_type.from_sarif(rel_location),
98+
)
99+
for rel_location in sarif_result.get("relatedLocations", [])
100+
]
101+
102+
@classmethod
103+
def extract_code_flows(cls, sarif_result) -> list[list[Location]]:
104+
return [
105+
[
106+
cls.location_type.from_sarif(locations.get("location"))
107+
for locations in threadflow.get("locations", {})
108+
]
109+
for codeflow in sarif_result.get("codeFlows", {})
110+
for threadflow in codeflow.get("threadFlows", {})
111+
]
112+
113+
@classmethod
114+
def extract_rule_id(cls, result, sarif_run, truncate_rule_id: bool = False) -> str:
115+
if rule_id := result.get("ruleId"):
116+
return rule_id.split(".")[-1] if truncate_rule_id else rule_id
117+
118+
# Otherwise the rule id may be referenced indirectly: the 'rule' field holds indexes into the run's tool extensions and their rules.
119+
if "rule" in result:
120+
tool_index = result["rule"]["toolComponent"]["index"]
121+
rule_index = result["rule"]["index"]
122+
return sarif_run["tool"]["extensions"][tool_index]["rules"][rule_index][
123+
"id"
124+
]
125+
126+
raise ValueError("Could not extract rule id from sarif result.")
127+
128+
61129
@dataclass(kw_only=True)
62130
class SASTResult(Result):
63131
finding_id: str

src/codemodder/semgrep.py

Lines changed: 4 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from codemodder.context import CodemodExecutionContext
1111
from codemodder.logging import logger
12-
from codemodder.result import LineInfo, Location, LocationWithMessage, Result, ResultSet
12+
from codemodder.result import LineInfo, Result, ResultSet, SarifLocation, SarifResult
1313
from codemodder.sarifs import AbstractSarifToolDetector
1414

1515

@@ -22,7 +22,7 @@ def detect(cls, run_data: dict) -> bool:
2222
)
2323

2424

25-
class SemgrepLocation(Location):
25+
class SemgrepLocation(SarifLocation):
2626
@classmethod
2727
def from_sarif(cls, sarif_location) -> Self:
2828
artifact_location = sarif_location["physicalLocation"]["artifactLocation"]
@@ -40,60 +40,8 @@ def from_sarif(cls, sarif_location) -> Self:
4040
return cls(file=file, start=start, end=end)
4141

4242

43-
class SemgrepResult(Result):
44-
@classmethod
45-
def extract_rule_id(
46-
cls, result, sarif_run, truncate_rule_id: bool = False
47-
) -> Optional[str]:
48-
if rule_id := result.get("ruleId"):
49-
return rule_id.split(".")[-1] if truncate_rule_id else rule_id
50-
51-
# it may be contained in the 'rule' field through the tool component in the sarif file
52-
if "rule" in result:
53-
tool_index = result["rule"]["toolComponent"]["index"]
54-
rule_index = result["rule"]["index"]
55-
return sarif_run["tool"]["extensions"][tool_index]["rules"][rule_index][
56-
"id"
57-
]
58-
59-
return None
60-
61-
@classmethod
62-
def from_sarif(
63-
cls, sarif_result, sarif_run, truncate_rule_id: bool = False
64-
) -> Self:
65-
if not (
66-
rule_id := cls.extract_rule_id(sarif_result, sarif_run, truncate_rule_id)
67-
):
68-
raise ValueError("Could not extract rule id from sarif result.")
69-
70-
locations: list[Location] = []
71-
for location in sarif_result["locations"]:
72-
artifact_location = SemgrepLocation.from_sarif(location)
73-
locations.append(artifact_location)
74-
all_flows: list[list[Location]] = [
75-
[
76-
SemgrepLocation.from_sarif(locations.get("location"))
77-
for locations in threadflow.get("locations", {})
78-
]
79-
for codeflow in sarif_result.get("codeFlows", {})
80-
for threadflow in codeflow.get("threadFlows", {})
81-
]
82-
related_locations: list[LocationWithMessage] = []
83-
if "relatedLocations" in sarif_result:
84-
related_locations = [
85-
LocationWithMessage(
86-
message=rel_location.get("message", {}).get("text", ""),
87-
location=SemgrepLocation.from_sarif(rel_location),
88-
)
89-
for rel_location in sarif_result.get("relatedLocations", [])
90-
]
91-
return cls(
92-
rule_id=rule_id,
93-
locations=locations,
94-
codeflows=all_flows,
95-
related_locations=related_locations,
96-
)
43+
class SemgrepResult(SarifResult):
44+
location_type = SemgrepLocation
9745

9846

9947
class SemgrepResultSet(ResultSet):

0 commit comments

Comments
 (0)