Skip to content

Commit 9090b4b

Browse files
added changes to the build_xml_element and save_to_xml_file
1 parent 34ef6cd commit 9090b4b

File tree

1 file changed

+25
-19
lines changed

1 file changed

+25
-19
lines changed

src/modelspec/utils.py

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import os
88
import math
99
import numpy as np
10+
import attr
1011

1112

1213
from modelspec.base_types import print_
@@ -160,12 +161,12 @@ def save_to_xml_file(info_dict, filename, indent=4, root="modelspec"):
160161
indent (int, optional): The number of spaces used for indentation in the XML file.
161162
Defaults to 4.
162163
"""
163-
root = ET.Element(root)
164+
# root = ET.Element(root)
164165

165-
build_xml_element(root, info_dict)
166+
root = build_xml_element(info_dict)
166167

167168
# Create an ElementTree object with the root element
168-
tree = ET.ElementTree(root)
169+
# tree = ET.ElementTree(root)
169170

170171
# Generate the XML string
171172
xml_str = ET.tostring(root, encoding="utf-8", method="xml").decode("utf-8")
@@ -179,7 +180,7 @@ def save_to_xml_file(info_dict, filename, indent=4, root="modelspec"):
179180
file.write(pretty_xml_str)
180181

181182

182-
def build_xml_element(parent, data):
183+
def build_xml_element(data, parent=None):
183184
"""
184185
This recursively builds an XML element structure from a dictionary or a list.
185186
@@ -188,22 +189,27 @@ def build_xml_element(parent, data):
188189
data: The data to convert into XML elements.
189190
190191
Returns:
191-
None
192+
Parent
192193
"""
193-
if isinstance(data, dict):
194-
for key, value in data.items():
195-
if isinstance(value, dict):
196-
element = ET.SubElement(parent, key.replace(" ", "_"))
197-
build_xml_element(element, value)
198-
elif isinstance(value, list):
199-
for item in value:
200-
subelement = ET.SubElement(parent, key.replace(" ", "_"))
201-
build_xml_element(subelement, item)
202-
else:
203-
element = ET.SubElement(parent, key.replace(" ", "_"))
204-
element.text = str(value)
205-
else:
206-
parent.text = str(data)
194+
if parent is None:
195+
parent = ET.Element(data.__class__.__name__)
196+
197+
attrs = attr.fields(data.__class__)
198+
for aattr in attrs:
199+
if isinstance(aattr.default, attr.Factory):
200+
children = data.__getattribute__(aattr.name)
201+
if not isinstance(children, (list, tuple)):
202+
children = [children]
203+
204+
for child in children:
205+
child_element = build_xml_element(child)
206+
parent.append(child_element)
207+
else:
208+
attribute_name = aattr.name
209+
attribute_value = data.__getattribute__(aattr.name)
210+
parent.set(attribute_name, str(attribute_value))
211+
212+
return parent
207213

208214

209215
def ascii_encode_dict(data):

0 commit comments

Comments
 (0)