Skip to content

Commit 4c59470

Browse files
modified test_base.py file to test xml serialization
1 parent bb5a17d commit 4c59470

File tree

1 file changed

+51
-4
lines changed

1 file changed

+51
-4
lines changed

tests/test_base.py

Lines changed: 51 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import List, Dict, Any, Optional
88

99
import sys
10+
from pathlib import Path
1011

1112
# Some test modelspec classes to use in the tests
1213

@@ -154,13 +155,17 @@ def test_save_load_json(tmp_path):
154155

155156
str_orig = str(net)
156157

157-
filenamej = str(tmp_path / f"{net.id}.json")
158+
filenamej = str(Path(tmp_path) / f"{net.id}.json")
158159
net.to_json_file(filenamej)
159160

160-
filenamey = str(tmp_path / f"{net.id}.yaml")
161+
filenamey = str(Path(tmp_path) / f"{net.id}.yaml")
161162
# net.id = net.id+'_yaml'
162163
net.to_yaml_file(filenamey)
163-
from modelspec.utils import load_json, load_yaml
164+
165+
filenamex = str(Path(tmp_path) / f"{net.id}.xml")
166+
net.to_xml_file(filenamex)
167+
168+
from modelspec.utils import load_json, load_yaml, load_xml
164169

165170
dataj = load_json(filenamej)
166171
print_v("Loaded network specification from %s" % filenamej)
@@ -174,12 +179,20 @@ def test_save_load_json(tmp_path):
174179
nety = NewNetwork.from_dict(datay)
175180
str_nety = str(nety)
176181

182+
datax = load_xml(filenamex)
183+
print_v("Loaded network specification from %s" % filenamex)
184+
185+
netx = NewNetwork.from_dict(datax)
186+
str_netx = str(netx)
187+
177188
print("----- Before -----")
178189
print(str_orig)
179190
print("----- After via %s -----" % filenamej)
180191
print(str_netj)
181192
print("----- After via %s -----" % filenamey)
182193
print(str_nety)
194+
print("----- After via %s -----" % filenamex)
195+
print(str_netx)
183196

184197
print("Test JSON..")
185198
if sys.version_info[0] == 2:
@@ -197,10 +210,19 @@ def test_save_load_json(tmp_path):
197210
else:
198211
assert str_orig == str_nety
199212

213+
print("Test XML..")
214+
if sys.version_info[0] == 2:
215+
assert len(str_orig) == len(
216+
str_netx
217+
) # Order not preserved in py2, just test len
218+
else:
219+
assert str_orig == str_netx
220+
200221
print("Test EvaluableExpressions")
201222
for i in range(7):
202223
assert eval("net.ee%i" % i) == eval("netj.ee%i" % i)
203224
assert eval("net.ee%i" % i) == eval("nety.ee%i" % i)
225+
assert eval("net.ee%i" % i) == eval("netx.ee%i" % i)
204226

205227

206228
def test_generate_documentation():
@@ -296,6 +318,7 @@ class Document(Base):
296318

297319
doc.to_json()
298320
doc.to_yaml()
321+
doc.to_xml()
299322
doc.generate_documentation(format="markdown")
300323
doc.generate_documentation(format="rst")
301324

@@ -314,10 +337,24 @@ class Node(Base):
314337
model.to_json()
315338

316339

340+
def test_ndarray_xml_metadata():
341+
import numpy as np
342+
343+
@modelspec.define(eq=False)
344+
class Node(Base):
345+
id: str = field(validator=instance_of(str))
346+
metadata: Optional[Dict[str, Any]] = field(
347+
kw_only=True, default=None, validator=optional(instance_of(dict))
348+
)
349+
350+
model = Node(id="a", metadata={"b": np.array([0])})
351+
model.to_xml()
352+
353+
317354
def test_bson_array(tmp_path):
318355
import numpy as np
319356

320-
test_filename = str(tmp_path / "test_array.bson")
357+
test_filename = str(Path(tmp_path) / "test_array.bson")
321358

322359
@modelspec.define(eq=False)
323360
class ArrayTestModel(Base):
@@ -339,3 +376,13 @@ class ArrayTestModel(Base):
339376
np.testing.assert_array_equal(model.array, m2.array)
340377
assert model.list_of_lists == m2.list_of_lists
341378
assert model.ragged_list == m2.ragged_list
379+
380+
381+
if __name__ == "__main__":
382+
test_save_load_json(".")
383+
test_generate_documentation()
384+
test_ndarray_json_metadata()
385+
test_ndarray_xml_metadata()
386+
test_generate_documentation_example()
387+
test_bson_array(".")
388+

0 commit comments

Comments
 (0)