7
7
import os
8
8
import math
9
9
import numpy as np
10
+ import attr
10
11
11
12
12
13
from modelspec .base_types import print_
@@ -160,12 +161,12 @@ def save_to_xml_file(info_dict, filename, indent=4, root="modelspec"):
160
161
indent (int, optional): The number of spaces used for indentation in the XML file.
161
162
Defaults to 4.
162
163
"""
163
- root = ET .Element (root )
164
+ # root = ET.Element(root)
164
165
165
- build_xml_element ( root , info_dict )
166
+ root = build_xml_element ( info_dict )
166
167
167
168
# Create an ElementTree object with the root element
168
- tree = ET .ElementTree (root )
169
+ # tree = ET.ElementTree(root)
169
170
170
171
# Generate the XML string
171
172
xml_str = ET .tostring (root , encoding = "utf-8" , method = "xml" ).decode ("utf-8" )
@@ -179,7 +180,7 @@ def save_to_xml_file(info_dict, filename, indent=4, root="modelspec"):
179
180
file .write (pretty_xml_str )
180
181
181
182
182
- def build_xml_element (parent , data ):
183
+ def build_xml_element (data , parent = None ):
183
184
"""
184
185
This recursively builds an XML element structure from a dictionary or a list.
185
186
@@ -188,22 +189,27 @@ def build_xml_element(parent, data):
188
189
data: The data to convert into XML elements.
189
190
190
191
Returns:
191
- None
192
+ Parent
192
193
"""
193
- if isinstance (data , dict ):
194
- for key , value in data .items ():
195
- if isinstance (value , dict ):
196
- element = ET .SubElement (parent , key .replace (" " , "_" ))
197
- build_xml_element (element , value )
198
- elif isinstance (value , list ):
199
- for item in value :
200
- subelement = ET .SubElement (parent , key .replace (" " , "_" ))
201
- build_xml_element (subelement , item )
202
- else :
203
- element = ET .SubElement (parent , key .replace (" " , "_" ))
204
- element .text = str (value )
205
- else :
206
- parent .text = str (data )
194
+ if parent is None :
195
+ parent = ET .Element (data .__class__ .__name__ )
196
+
197
+ attrs = attr .fields (data .__class__ )
198
+ for aattr in attrs :
199
+ if isinstance (aattr .default , attr .Factory ):
200
+ children = data .__getattribute__ (aattr .name )
201
+ if not isinstance (children , (list , tuple )):
202
+ children = [children ]
203
+
204
+ for child in children :
205
+ child_element = build_xml_element (child )
206
+ parent .append (child_element )
207
+ else :
208
+ attribute_name = aattr .name
209
+ attribute_value = data .__getattribute__ (aattr .name )
210
+ parent .set (attribute_name , str (attribute_value ))
211
+
212
+ return parent
207
213
208
214
209
215
def ascii_encode_dict (data ):
0 commit comments