Skip to content

Commit 1b8e814

Browse files
authored
XML Transformer and Pipeline (#584)
* Initial implementation of XMLTransformer and Pipeline * Bugfixes and tests for XML transformer
1 parent 955c0dc commit 1b8e814

File tree

3 files changed

+272
-2
lines changed

3 files changed

+272
-2
lines changed

src/codemodder/codemods/codeql.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@ def apply(
2222

2323

2424
@cache
25-
def process_codeql_findings(semgrep_sarif_files: tuple[str]) -> ResultSet:
25+
def process_codeql_findings(codeql_sarif_files: tuple[str]) -> ResultSet:
2626
results = CodeQLResultSet()
27-
for file in semgrep_sarif_files or ():
27+
for file in codeql_sarif_files or ():
2828
results |= CodeQLResultSet.from_sarif(file)
2929
return results
Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
import mmap
2+
from tempfile import TemporaryFile
3+
from xml.sax import SAXParseException, handler
4+
from xml.sax.handler import LexicalHandler
5+
from xml.sax.saxutils import XMLGenerator
6+
from xml.sax.xmlreader import Locator
7+
8+
from defusedxml.sax import make_parser
9+
10+
from codemodder.codemods.base_transformer import BaseTransformerPipeline
11+
from codemodder.codetf import Change, ChangeSet
12+
from codemodder.context import CodemodExecutionContext
13+
from codemodder.diff import create_diff
14+
from codemodder.file_context import FileContext
15+
from codemodder.result import Result
16+
17+
18+
class XMLTransformer(XMLGenerator, LexicalHandler):
19+
"""
20+
Given a XML file, generates the same file but formatted.
21+
"""
22+
23+
change_description = ""
24+
25+
def __init__(
26+
self,
27+
out,
28+
encoding: str = "utf-8",
29+
short_empty_elements: bool = False,
30+
results: list[Result] | None = None,
31+
) -> None:
32+
self.results = results
33+
self.changes: list[Change] = []
34+
self._my_locator = Locator()
35+
super().__init__(out, encoding, short_empty_elements)
36+
37+
def startElement(self, name, attrs):
38+
super().startElement(name, attrs)
39+
40+
def endElement(self, name):
41+
super().endElement(name)
42+
43+
def characters(self, content):
44+
super().characters(content)
45+
46+
def skippedEntity(self, name: str) -> None:
47+
super().skippedEntity(name)
48+
49+
def comment(self, content: str):
50+
self._write(f"<!--{content}-->\n") # type: ignore
51+
52+
def startCDATA(self):
53+
self._write("<![CDATA[") # type: ignore
54+
55+
def endCDATA(self):
56+
self._write("]]>") # type: ignore
57+
58+
def startDTD(self, name: str, public_id: str | None, system_id: str | None):
59+
self._write(f'<!DOCTYPE {name} PUBLIC "{public_id}" "{system_id}">\n') # type: ignore
60+
return super().startDTD(name, public_id, system_id)
61+
62+
def endDTD(self) -> object:
63+
return super().endDTD()
64+
65+
def setDocumentLocator(self, locator: Locator) -> None:
66+
self._my_locator = locator
67+
68+
def event_match_result(self) -> bool:
69+
"""
70+
Returns True if the current event matches any result.
71+
"""
72+
line = self._my_locator.getLineNumber()
73+
column = self._my_locator.getColumnNumber()
74+
return self.match_result(line, column)
75+
76+
def match_result(self, line, column) -> bool:
77+
if self.results is None:
78+
return True
79+
for result in self.results or []:
80+
for location in result.locations:
81+
# No two elements can have the same start but different ends.
82+
# It suffices to only match the start.
83+
if location.start.line == line and location.start.column - 1 == column:
84+
return True
85+
return False
86+
87+
def add_change(self, line):
88+
self.changes.append(
89+
Change(lineNumber=line, description=self.change_description)
90+
)
91+
92+
93+
class ElementAttributeXMLTransformer(XMLTransformer):
94+
"""
95+
Changes the element and its attributes to the values provided in a given dict. For any attribute missing in the dict will stay the same as the original.
96+
"""
97+
98+
def __init__(
99+
self,
100+
out,
101+
name_attributes_map: dict[str, dict[str, str]],
102+
encoding: str = "utf-8",
103+
short_empty_elements: bool = False,
104+
results: list[Result] | None = None,
105+
) -> None:
106+
self.name_attributes_map = name_attributes_map
107+
super().__init__(out, encoding, short_empty_elements, results)
108+
109+
def startElement(self, name, attrs):
110+
new_attrs = attrs
111+
if self.event_match_result() and name in self.name_attributes_map:
112+
new_attrs = self.name_attributes_map[name]
113+
self.add_change(self._my_locator.getLineNumber())
114+
super().startElement(name, new_attrs)
115+
116+
117+
class XMLTransformerPipeline(BaseTransformerPipeline):
118+
119+
def __init__(self, xml_transformer: type[XMLTransformer]):
120+
super().__init__()
121+
self.xml_transformer = xml_transformer
122+
123+
def apply(
124+
self,
125+
context: CodemodExecutionContext,
126+
file_context: FileContext,
127+
results: list[Result] | None,
128+
) -> ChangeSet | None:
129+
if file_context.file_path.suffix.lower() not in (".config", ".xml"):
130+
return None
131+
132+
changes = []
133+
with TemporaryFile("w+") as output_file:
134+
135+
# this will fail fast for files that are not XML
136+
try:
137+
transformer_instance = self.xml_transformer(
138+
out=output_file, results=results
139+
)
140+
parser = make_parser()
141+
parser.setContentHandler(transformer_instance)
142+
parser.setProperty(
143+
handler.property_lexical_handler, transformer_instance
144+
)
145+
parser.parse(file_context.file_path)
146+
changes = transformer_instance.changes
147+
output_file.seek(0)
148+
149+
except SAXParseException:
150+
return None
151+
152+
diff = ""
153+
with open(file_context.file_path, "r") as original:
154+
# don't calculate diff if no changes were reported
155+
# TODO there's a failure potential here for very large files
156+
diff = (
157+
create_diff(
158+
original.readlines(),
159+
output_file.readlines(),
160+
)
161+
if changes
162+
else ""
163+
)
164+
165+
if not context.dry_run:
166+
with open(file_context.file_path, "w+b") as original:
167+
# mmap can't map empty files, write something first
168+
original.write(b"a")
169+
# copy contents of result into original file
170+
# the snippet below preserves the original file metadata and accounts for large files.
171+
output_file.seek(0)
172+
output_mmap = mmap.mmap(output_file.fileno(), 0)
173+
174+
original.truncate()
175+
original_mmap = mmap.mmap(original.fileno(), 0)
176+
original_mmap.resize(len(output_mmap))
177+
original_mmap[:] = output_mmap
178+
original_mmap.flush()
179+
180+
return ChangeSet(
181+
path=str(file_context.file_path.relative_to(context.directory)),
182+
diff=diff,
183+
changes=changes,
184+
)

tests/test_xml_transformer.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
from io import StringIO
2+
from textwrap import dedent
3+
from xml.sax import handler
4+
5+
import pytest
6+
from defusedxml import ExternalReferenceForbidden
7+
from defusedxml.sax import make_parser
8+
9+
from codemodder.codemods.xml_transformer import (
10+
ElementAttributeXMLTransformer,
11+
XMLTransformer,
12+
)
13+
14+
15+
class TestXMLTransformer:
16+
17+
def run_and_assert(self, input_code, expected_output):
18+
with StringIO() as result, StringIO(dedent(input_code)) as input_stream:
19+
result = StringIO()
20+
transformer = XMLTransformer(result)
21+
parser = make_parser()
22+
parser.setContentHandler(transformer)
23+
parser.setProperty(handler.property_lexical_handler, transformer)
24+
parser.parse(input_stream)
25+
assert result.getvalue() == dedent(expected_output)
26+
27+
def test_parse_dtd_forbidden(self):
28+
input_code = """\
29+
<?xml version="1.0" encoding="utf-8"?>
30+
<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Transitional//EN" "http://www.w3.org/TR/xhtml1/DTD/xhtml1-transitional.dtd">
31+
"""
32+
expected_output = input_code
33+
with pytest.raises(ExternalReferenceForbidden):
34+
self.run_and_assert(input_code, expected_output)
35+
36+
def test_parse_comment(self):
37+
input_code = """\
38+
<?xml version="1.0" encoding="utf-8"?>
39+
<!-- comment -->
40+
<element></element>"""
41+
expected_output = input_code
42+
self.run_and_assert(input_code, expected_output)
43+
44+
def test_parse_cdata(self):
45+
input_code = """\
46+
<?xml version="1.0" encoding="utf-8"?>
47+
<element>
48+
<![CDATA[some characters]]>
49+
</element>"""
50+
expected_output = input_code
51+
self.run_and_assert(input_code, expected_output)
52+
53+
54+
class TestElementAttributeXMLTransformer:
55+
56+
def run_and_assert(self, name_attr_map, input_code, expected_output):
57+
with StringIO() as result, StringIO(dedent(input_code)) as input_stream:
58+
result = StringIO()
59+
transformer = ElementAttributeXMLTransformer(
60+
result, name_attributes_map=name_attr_map
61+
)
62+
parser = make_parser()
63+
parser.setContentHandler(transformer)
64+
parser.setProperty(handler.property_lexical_handler, transformer)
65+
parser.parse(input_stream)
66+
assert result.getvalue() == dedent(expected_output)
67+
68+
def test_change_single_attr(self):
69+
input_code = """\
70+
<?xml version="1.0" encoding="utf-8"?>
71+
<element attr="false"></element>"""
72+
expected_output = """\
73+
<?xml version="1.0" encoding="utf-8"?>
74+
<element attr="true"></element>"""
75+
name_attr_map = {"element": {"attr": "true"}}
76+
self.run_and_assert(name_attr_map, input_code, expected_output)
77+
78+
def test_change_multiple_attr(self):
79+
input_code = """\
80+
<?xml version="1.0" encoding="utf-8"?>
81+
<element first="1" second="2"></element>"""
82+
expected_output = """\
83+
<?xml version="1.0" encoding="utf-8"?>
84+
<element first="one" second="two"></element>"""
85+
name_attr_map = {"element": {"first": "one", "second": "two"}}
86+
self.run_and_assert(name_attr_map, input_code, expected_output)

0 commit comments

Comments
 (0)