Skip to content

Commit 7baeabe

Browse files
authored
Xml transformer (#598)
* xml transformer can add new element * can add nested elements * cleanup
1 parent 8077b19 commit 7baeabe

File tree

2 files changed

+135
-1
lines changed

2 files changed

+135
-1
lines changed

src/codemodder/codemods/xml_transformer.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import mmap
2+
from dataclasses import dataclass, field
23
from tempfile import TemporaryFile
34
from xml.sax import SAXParseException, handler
45
from xml.sax.handler import LexicalHandler
@@ -114,6 +115,50 @@ def startElement(self, name, attrs):
114115
super().startElement(name, new_attrs)
115116

116117

118+
@dataclass
119+
class NewElement:
120+
name: str
121+
parent_name: str
122+
content: str = ""
123+
attributes: dict[str, str] = field(default_factory=dict)
124+
125+
126+
class NewElementXMLTransformer(XMLTransformer):
127+
"""
128+
Adds new elements to the XML file at specified locations.
129+
"""
130+
131+
def __init__(
132+
self,
133+
out,
134+
encoding: str = "utf-8",
135+
short_empty_elements: bool = False,
136+
results: list[Result] | None = None,
137+
new_elements: list[NewElement] | None = None,
138+
) -> None:
139+
super().__init__(out, encoding, short_empty_elements, results)
140+
self.new_elements = new_elements or []
141+
142+
def startElement(self, name, attrs):
143+
super().startElement(name, attrs)
144+
145+
def endElement(self, name):
146+
for new_element in self.new_elements:
147+
if new_element.parent_name == name:
148+
self.add_new_element(new_element)
149+
self.add_change(self._my_locator.getLineNumber())
150+
super().endElement(name)
151+
152+
def add_new_element(self, new_element: NewElement):
153+
attrs = AttributesImpl(new_element.attributes or {})
154+
super().startElement(new_element.name, attrs)
155+
if isinstance(new_element.content, NewElement):
156+
self.add_new_element(new_element.content)
157+
else:
158+
super().characters(new_element.content)
159+
super().endElement(new_element.name)
160+
161+
117162
class XMLTransformerPipeline(BaseTransformerPipeline):
118163

119164
def __init__(self, xml_transformer: type[XMLTransformer]):

tests/test_xml_transformer.py

Lines changed: 90 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88

99
from codemodder.codemods.xml_transformer import (
1010
ElementAttributeXMLTransformer,
11+
NewElement,
12+
NewElementXMLTransformer,
1113
XMLTransformer,
1214
)
1315

@@ -55,7 +57,6 @@ class TestElementAttributeXMLTransformer:
5557

5658
def run_and_assert(self, name_attr_map, input_code, expected_output):
5759
with StringIO() as result, StringIO(dedent(input_code)) as input_stream:
58-
result = StringIO()
5960
transformer = ElementAttributeXMLTransformer(
6061
result, name_attributes_map=name_attr_map
6162
)
@@ -94,3 +95,91 @@ def test_change_multiple_attr_and_preserve_existing(self):
9495
<element first="one" second="two" three="three"></element>"""
9596
name_attr_map = {"element": {"first": "one", "second": "two"}}
9697
self.run_and_assert(name_attr_map, input_code, expected_output)
98+
99+
100+
class TestNewElementXMLTransformer:
101+
102+
def run_and_assert(self, new_elements, input_code, expected_output):
103+
with StringIO() as result, StringIO(dedent(input_code)) as input_stream:
104+
transformer = NewElementXMLTransformer(result, new_elements=new_elements)
105+
parser = make_parser()
106+
parser.setContentHandler(transformer)
107+
parser.setProperty(handler.property_lexical_handler, transformer)
108+
parser.parse(input_stream)
109+
assert result.getvalue() == dedent(expected_output)
110+
111+
def test_add_new_element(self):
112+
input_code = """\
113+
<root></root>
114+
"""
115+
expected_output = """\
116+
<?xml version="1.0" encoding="utf-8"?>
117+
<root><child1></child1><child2 one="1">2</child2></root>"""
118+
new_elements = [
119+
NewElement(name="child1", parent_name="root"),
120+
NewElement(
121+
name="child2", parent_name="root", content="2", attributes={"one": "1"}
122+
),
123+
]
124+
self.run_and_assert(new_elements, input_code, expected_output)
125+
126+
def test_add_new_sibling_same_name(self):
127+
input_code = """\
128+
<root>
129+
<child>child 1</child>
130+
</root>
131+
"""
132+
expected_output = """\
133+
<?xml version="1.0" encoding="utf-8"?>
134+
<root>
135+
<child>child 1</child>
136+
<child>child 2</child></root>"""
137+
new_elements = [
138+
NewElement(name="child", parent_name="root", content="child 2"),
139+
]
140+
self.run_and_assert(new_elements, input_code, expected_output)
141+
142+
def test_add_nested_elementsl(self):
143+
input_code = """\
144+
<?xml version="1.0" encoding="utf-8" ?>
145+
<configuration>
146+
<system.web>
147+
</system.web>
148+
<system.webServer>
149+
<validation validateIntegratedModeConfiguration="false" />
150+
<modules>
151+
<remove name="ScriptModule" />
152+
<add name="ScriptModule" preCondition="managedHandler" type="System.Web.Handlers.ScriptModule, System.Web.Extensions, Version=3.5.0.0, Culture=neutral, PublicKeyToken=31BF3856AD364E35" />
153+
</modules>
154+
</system.webServer>
155+
</configuration>
156+
"""
157+
expected_output = """\
158+
<?xml version="1.0" encoding="utf-8"?>
159+
<configuration>
160+
<system.web>
161+
</system.web>
162+
<system.webServer>
163+
<validation validateIntegratedModeConfiguration="false"></validation>
164+
<modules>
165+
<remove name="ScriptModule"></remove>
166+
<add name="ScriptModule" preCondition="managedHandler" type="System.Web.Handlers.ScriptModule, System.Web.Extensions, Version=3.5.0.0, Culture=neutral, PublicKeyToken=31BF3856AD364E35"></add>
167+
</modules>
168+
<httpProtocol><customHeaders><add name="X-Frame-Options" value="DENY"></add></customHeaders></httpProtocol></system.webServer>
169+
</configuration>"""
170+
new_elements = [
171+
NewElement(
172+
name="httpProtocol",
173+
parent_name="system.webServer",
174+
content=NewElement(
175+
name="customHeaders",
176+
parent_name="httpProtocol",
177+
content=NewElement(
178+
name="add",
179+
parent_name="customHeaders",
180+
attributes={"name": "X-Frame-Options", "value": "DENY"},
181+
),
182+
),
183+
),
184+
]
185+
self.run_and_assert(new_elements, input_code, expected_output)

0 commit comments

Comments
 (0)