Skip to content

Commit 70d74b3

Browse files
authored
Support for appdef extends (#339)
* Read extends keyword from file * Insert extends parents into the inheritance chain * Automatically populate tree from appdef parents * Only populate tree if parents are present * Docstring improvements * Fix exact match in NX_CLASS[path] notation * If minOccurs == 0, set the group to optional * Add extended NXtest
1 parent 3a7d63d commit 70d74b3

File tree

5 files changed

+205
-20
lines changed

5 files changed

+205
-20
lines changed
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
2+
<?xml-stylesheet type="text/xsl" href="nxdlformat.xsl" ?>
3+
<definition xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
4+
xsi:schemaLocation="http://definition.nexusformat.org/nxdl/3.1 ../nxdl.xsd"
5+
xmlns="http://definition.nexusformat.org/nxdl/3.1"
6+
name="NXtest_extended"
7+
extends="NXtest"
8+
type="group"
9+
category="application"
10+
>
11+
<doc>This is a dummy NXDL to test an extended application definition.</doc>
12+
<group type="NXentry">
13+
<field name="definition">
14+
<doc>This is a dummy NXDL to test out the dataconverter.</doc>
15+
<enumeration>
16+
<item value="NXtest_extended"/>
17+
</enumeration>
18+
</field>
19+
<field name="extended_field" type="NX_FLOAT" units="NX_ENERGY">
20+
<doc>A dummy entry for an extended field.</doc>
21+
</field>
22+
</group>
23+
</definition>

src/pynxtools/dataconverter/helpers.py

Lines changed: 81 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -193,6 +193,75 @@ def get_nxdl_name_from_elem(xml_element) -> str:
193193
return name_to_add
194194

195195

196+
def get_nxdl_name_for(xml_elem: ET._Element) -> Optional[str]:
197+
"""
198+
Get the name of the element from the NXDL element.
199+
For an entity having a name this is just the name.
200+
For groups it is the uppercase type without NX, e.g. "ENTRY" for "NXentry".
201+
202+
Args:
203+
xml_elem (ET._Element): The xml element to get the name for.
204+
205+
Returns:
206+
Optional[str]:
207+
The name of the element.
208+
None if the xml element has no name or type attribute.
209+
"""
210+
""""""
211+
if "name" in xml_elem.attrib:
212+
return xml_elem.attrib["name"]
213+
if "type" in xml_elem.attrib:
214+
return convert_nexus_to_caps(xml_elem.attrib["type"])
215+
return None
216+
217+
218+
def get_appdef_root(xml_elem: ET._Element) -> ET._Element:
219+
"""
220+
Get the root element of the tree of xml_elem
221+
222+
Args:
223+
xml_elem (ET._Element): The element for which to get the root element.
224+
225+
Returns:
226+
ET._Element: The root element of the tree.
227+
"""
228+
return xml_elem.getroottree().getroot()
229+
230+
231+
def is_appdef(xml_elem: ET._Element) -> bool:
232+
"""
233+
Check whether the xml element is part of an application definition.
234+
235+
Args:
236+
xml_elem (ET._Element): The xml_elem whose tree to check.
237+
238+
Returns:
239+
bool: True if the xml_elem is part of an application definition.
240+
"""
241+
return get_appdef_root(xml_elem).attrib.get("category") == "application"
242+
243+
244+
def get_all_parents_for(xml_elem: ET._Element) -> List[ET._Element]:
245+
"""
246+
Get all parents from the nxdl (via extends keyword)
247+
248+
Args:
249+
xml_elem (ET._Element): The element to get the parents for.
250+
251+
Returns:
252+
List[ET._Element]: The list of parents xml nodes.
253+
"""
254+
root = get_appdef_root(xml_elem)
255+
inheritance_chain = []
256+
extends = root.get("extends")
257+
while extends is not None and extends != "NXobject":
258+
parent_xml_root, _ = get_nxdl_root_and_path(extends)
259+
extends = parent_xml_root.get("extends")
260+
inheritance_chain.append(parent_xml_root)
261+
262+
return inheritance_chain
263+
264+
196265
def get_nxdl_root_and_path(nxdl: str):
197266
"""Get xml root element and file path from nxdl name e.g. NXapm.
198267
@@ -213,16 +282,20 @@ def get_nxdl_root_and_path(nxdl: str):
213282
FileNotFoundError
214283
Error if no file with the given nxdl name is found.
215284
"""
285+
216286
# Reading in the NXDL and generating a template
217287
definitions_path = nexus.get_nexus_definitions_path()
218-
if nxdl == "NXtest":
219-
nxdl_f_path = os.path.join(
220-
f"{os.path.abspath(os.path.dirname(__file__))}/../",
221-
"data",
222-
"NXtest.nxdl.xml",
223-
)
224-
elif nxdl == "NXroot":
225-
nxdl_f_path = os.path.join(definitions_path, "base_classes", "NXroot.nxdl.xml")
288+
data_path = os.path.join(
289+
f"{os.path.abspath(os.path.dirname(__file__))}/../",
290+
"data",
291+
)
292+
special_names = {
293+
"NXtest": os.path.join(data_path, "NXtest.nxdl.xml"),
294+
"NXtest_extended": os.path.join(data_path, "NXtest_extended.nxdl.xml"),
295+
}
296+
297+
if nxdl in special_names:
298+
nxdl_f_path = special_names[nxdl]
226299
else:
227300
nxdl_f_path = os.path.join(
228301
definitions_path, "contributed_definitions", f"{nxdl}.nxdl.xml"

src/pynxtools/dataconverter/nexus_tree.py

Lines changed: 43 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,10 @@
3535

3636
from pynxtools.dataconverter.helpers import (
3737
contains_uppercase,
38+
get_all_parents_for,
39+
get_nxdl_name_for,
3840
get_nxdl_root_and_path,
41+
is_appdef,
3942
remove_namespace_from_tag,
4043
)
4144

@@ -92,7 +95,7 @@
9295
]
9396

9497
# This is the NeXus namespace for finding tags.
95-
# It's updated from the nxdl file when `generate_tree_from`` is called.
98+
# It's updated from the nxdl file when `generate_tree_from` is called.
9699
namespaces = {"nx": "http://definition.nexusformat.org/nxdl/3.1"}
97100

98101

@@ -117,10 +120,6 @@ class NexusNode(NodeMixin):
117120
This is set automatically on init and will be True if the name contains
118121
any uppercase characets and False otherwise.
119122
Defaults to False.
120-
variadic_siblings (List[InstanceOf["NexusNode"]]):
121-
Variadic siblings are names which are connected to each other, e.g.,
122-
`AXISNAME` and `AXISNAME_indices` belong together and are variadic siblings.
123-
Defaults to [].
124123
inheritance (List[InstanceOf[ET._Element]]):
125124
The inheritance chain of the node.
126125
The first element of the list is the xml representation of this node.
@@ -146,7 +145,10 @@ def _set_optionality(self):
146145
return
147146
if self.inheritance[0].attrib.get("recommended"):
148147
self.optionality = "recommended"
149-
elif self.inheritance[0].attrib.get("optional"):
148+
elif (
149+
self.inheritance[0].attrib.get("optional")
150+
or self.inheritance[0].attrib.get("minOccurs") == "0"
151+
):
150152
self.optionality = "optional"
151153

152154
def __init__(
@@ -223,7 +225,9 @@ def search_child_with_name(
223225
return self.add_inherited_node(name)
224226
return None
225227

226-
def get_all_children_names(self, depth: Optional[int] = None) -> Set[str]:
228+
def get_all_children_names(
229+
self, depth: Optional[int] = None, only_appdef: bool = False
230+
) -> Set[str]:
227231
"""
228232
Get all children names of the current node up to a certain depth.
229233
Only `field`, `group` `choice` or `attribute` are considered as children.
@@ -234,6 +238,9 @@ def get_all_children_names(self, depth: Optional[int] = None) -> Set[str]:
234238
`depth=1` will return only the children of the current node.
235239
`depth=None` will return all children names of all parents.
236240
Defaults to None.
241+
only_appdef (bool, optional):
242+
Only considers appdef nodes as children.
243+
Defaults to False.
237244
238245
Raises:
239246
ValueError: If depth is not int or negativ.
@@ -246,6 +253,9 @@ def get_all_children_names(self, depth: Optional[int] = None) -> Set[str]:
246253

247254
names = set()
248255
for elem in self.inheritance[:depth]:
256+
if only_appdef and not is_appdef(elem):
257+
break
258+
249259
for subelems in elem.xpath(
250260
(
251261
r"*[self::nx:field or self::nx:group "
@@ -354,6 +364,7 @@ def _build_inheritance_chain(self, xml_elem: ET._Element) -> List[ET._Element]:
354364
inheritance_chain.append(inherited_elem[0])
355365
bc_xml_root, _ = get_nxdl_root_and_path(xml_elem.attrib["type"])
356366
inheritance_chain.append(bc_xml_root)
367+
inheritance_chain += get_all_parents_for(bc_xml_root)
357368

358369
return inheritance_chain
359370

@@ -371,13 +382,15 @@ def add_node_from(self, xml_elem: ET._Element) -> Optional["NexusNode"]:
371382
The children node which was added.
372383
None if the tag of the xml element is not known.
373384
"""
385+
default_optionality = "required" if is_appdef(xml_elem) else "optional"
374386
tag = remove_namespace_from_tag(xml_elem.tag)
375387
if tag in ("field", "attribute"):
376388
name = xml_elem.attrib.get("name")
377389
current_elem = NexusEntity(
378390
parent=self,
379391
name=name,
380392
type=tag,
393+
optionality=default_optionality,
381394
)
382395
elif tag == "group":
383396
name = xml_elem.attrib.get("name", xml_elem.attrib["type"][2:].upper())
@@ -388,12 +401,14 @@ def add_node_from(self, xml_elem: ET._Element) -> Optional["NexusNode"]:
388401
name=name,
389402
nx_class=xml_elem.attrib["type"],
390403
inheritance=inheritance_chain,
404+
optionality=default_optionality,
391405
)
392406
elif tag == "choice":
393407
current_elem = NexusChoice(
394408
parent=self,
395409
name=xml_elem.attrib["name"],
396410
variadic=contains_uppercase(xml_elem.attrib["name"]),
411+
optionality=default_optionality,
397412
)
398413
else:
399414
# TODO: Tags: link
@@ -428,7 +443,6 @@ def add_inherited_node(self, name: str) -> Optional["NexusNode"]:
428443
)
429444
if xml_elem:
430445
new_node = self.add_node_from(xml_elem[0])
431-
new_node.optionality = "optional"
432446
return new_node
433447
return None
434448

@@ -616,6 +630,19 @@ def __repr__(self) -> str:
616630
return f"{self.name} ({self.optionality[:3]})"
617631

618632

633+
def populate_tree_from_parents(node: NexusNode):
634+
"""
635+
Recursively populate the tree from the appdef parents (via extends keyword).
636+
637+
Args:
638+
node (NexusNode):
639+
The current node from which to populate the tree.
640+
"""
641+
for child in node.get_all_children_names(only_appdef=True):
642+
child_node = node.search_child_with_name(child)
643+
populate_tree_from_parents(child_node)
644+
645+
619646
def generate_tree_from(appdef: str) -> NexusNode:
620647
"""
621648
Generates a NexusNode tree from an application definition.
@@ -655,14 +682,17 @@ def add_children_to(parent: NexusNode, xml_elem: ET._Element) -> None:
655682
global namespaces
656683
namespaces = {"nx": appdef_xml_root.nsmap[None]}
657684

685+
appdef_inheritance_chain = [appdef_xml_root]
686+
appdef_inheritance_chain += get_all_parents_for(appdef_xml_root)
687+
658688
tree = NexusGroup(
659689
name=appdef_xml_root.attrib["name"],
660690
nx_class="NXroot",
661691
type="group",
662692
optionality="required",
663693
variadic=False,
664694
parent=None,
665-
inheritance=[appdef_xml_root],
695+
inheritance=appdef_inheritance_chain,
666696
)
667697
# Set root attributes
668698
nx_root, _ = get_nxdl_root_and_path("NXroot")
@@ -673,4 +703,8 @@ def add_children_to(parent: NexusNode, xml_elem: ET._Element) -> None:
673703
entry = appdef_xml_root.find("nx:group[@type='NXentry']", namespaces=namespaces)
674704
add_children_to(tree, entry)
675705

706+
# Add all fields and attributes from the parent appdefs
707+
if len(appdef_inheritance_chain) > 1:
708+
populate_tree_from_parents(tree)
709+
676710
return tree

src/pynxtools/dataconverter/validation.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
Collector,
3232
ValidationProblem,
3333
collector,
34+
convert_nexus_to_caps,
3435
is_valid_data_field,
3536
)
3637
from pynxtools.dataconverter.nexus_tree import (
@@ -186,8 +187,14 @@ def validate_dict_against(
186187
"""
187188

188189
def get_variations_of(node: NexusNode, keys: Mapping[str, Any]) -> List[str]:
189-
if not node.variadic and node.name in keys:
190-
return [node.name]
190+
if not node.variadic:
191+
if node.name in keys:
192+
return [node.name]
193+
elif (
194+
hasattr(node, "nx_class")
195+
and f"{convert_nexus_to_caps(node.nx_class)}[{node.name}]" in keys
196+
):
197+
return [f"{convert_nexus_to_caps(node.nx_class)}[{node.name}]"]
191198

192199
variations = []
193200
for key in keys:

tests/dataconverter/test_nexus_tree.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
from typing import get_args
1+
from typing import Any, List, Tuple, get_args
22

3+
from anytree import Resolver
34
from pynxtools.dataconverter.nexus_tree import (
5+
NexusNode,
46
NexusType,
57
NexusUnitCategory,
68
generate_tree_from,
@@ -31,3 +33,49 @@ def test_if_all_types_are_present():
3133
pydantic_literal_values = get_args(NexusType)
3234

3335
assert set(reference_types) == set(pydantic_literal_values)
36+
37+
38+
def test_correct_extension_of_tree():
39+
nxtest = generate_tree_from("NXtest")
40+
nxtest_extended = generate_tree_from("NXtest_extended")
41+
42+
def get_node_fields(tree: NexusNode) -> List[Tuple[str, Any]]:
43+
return list(
44+
filter(
45+
lambda x: not x[0].startswith("_") and x[0] not in "inheritance",
46+
tree.__dict__.items(),
47+
)
48+
)
49+
50+
def left_tree_in_right_tree(left_tree, right_tree):
51+
for left_child in left_tree.children:
52+
if left_child.name not in map(lambda x: x.name, right_tree.children):
53+
return False
54+
right_child = list(
55+
filter(lambda x: x.name == left_child.name, right_tree.children)
56+
)[0]
57+
if left_child.name == "definition":
58+
# Definition should be overwritten
59+
if not left_child.items == ["NXTEST", "NXtest"]:
60+
return False
61+
if not right_child.items == ["NXtest_extended"]:
62+
return False
63+
continue
64+
for field in get_node_fields(left_child):
65+
if field not in get_node_fields(right_child):
66+
return False
67+
if not left_tree_in_right_tree(left_child, right_child):
68+
return False
69+
return True
70+
71+
assert left_tree_in_right_tree(nxtest, nxtest_extended)
72+
73+
resolver = Resolver("name", relax=True)
74+
extended_field = resolver.get(nxtest_extended, "ENTRY/extended_field")
75+
assert extended_field is not None
76+
assert extended_field.unit == "NX_ENERGY"
77+
assert extended_field.dtype == "NX_FLOAT"
78+
assert extended_field.optionality == "required"
79+
80+
nxtest_field = resolver.get(nxtest, "ENTRY/extended_field")
81+
assert nxtest_field is None

0 commit comments

Comments
 (0)