Skip to content

Commit 1e4a516

Browse files
committed
use hints for NXdata fields, log only once
1 parent a23b920 commit 1e4a516

File tree

6 files changed

+112
-82
lines changed

6 files changed

+112
-82
lines changed

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,6 @@ validate_nexus = "pynxtools.dataconverter.validate_file:validate_cli"
112112

113113
[tool.setuptools.package-data]
114114
pynxtools = ["definitions/**/*.xml", "definitions/**/*.xsd"]
115-
"pynxtools.dataconverter.units" = ["*.txt"]
116115
"pynxtools.dataconverter.readers.hall" = ["enum_map.json"]
117116
"pynxtools.dataconverter.readers.rii_database.formula_parser" = ["dispersion_function_grammar.lark"]
118117

src/pynxtools/dataconverter/helpers.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,11 @@ class Collector:
8484
"""A class to collect data and return it in a dictionary format."""
8585

8686
def __init__(self):
87-
self.data = set()
87+
self.data: dict[str, set] = {
88+
"warning_and_error": set(),
89+
"info": set(),
90+
}
91+
8892
self.logging = True
8993

9094
def _log(self, path: str, log_type: ValidationProblem, value: Optional[Any], *args):
@@ -223,26 +227,36 @@ def collect_and_log(
223227
"NX_ANY",
224228
):
225229
return
226-
if self.logging and path + str(log_type) + str(value) not in self.data:
227-
self._log(path, log_type, value, *args, **kwargs)
230+
231+
message: str = path + str(log_type) + str(value)
232+
228233
# info messages should not fail validation
229-
if log_type not in (
234+
if log_type in (
230235
ValidationProblem.UnitWithoutDocumentation,
231236
ValidationProblem.OpenEnumWithNewItem,
232237
):
233-
self.data.add(path + str(log_type) + str(value))
238+
if self.logging and message not in self.data["info"]:
239+
self._log(path, log_type, value, *args, **kwargs)
240+
self.data["info"].add(message)
241+
else:
242+
if self.logging and message not in self.data["warning_and_error"]:
243+
self._log(path, log_type, value, *args, **kwargs)
244+
self.data["warning_and_error"].add(message)
234245

235246
def has_validation_problems(self) -> bool:
236247
"""Returns True if there were any validation problems."""
237-
return len(self.data) > 0
248+
return len(self.data["warning_and_error"]) > 0
238249

239250
def get(self):
240251
"""Returns the set of problematic paths."""
241-
return self.data
252+
return self.data["warning_and_error"]
242253

243254
def clear(self):
244255
"""Clears the collected data."""
245-
self.data = set()
256+
self.data: dict[str, set] = {
257+
"warning_and_error": set(),
258+
"info": set(),
259+
}
246260

247261

248262
collector = Collector()

src/pynxtools/dataconverter/validate_file.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,12 +95,12 @@ def validate(file: str, ignore_undocumented: bool = False):
9595

9696
if is_valid:
9797
logger.info(
98-
f"The entry `{entry}` in file `{file}` is a valid file"
98+
f"The entry `{entry}` in file `{file}` is valid"
9999
f" according to the `{nxdl}` application definition.",
100100
)
101101
else:
102102
logger.info(
103-
f"Invalid: The entry `{entry}` in file `{file}` is NOT a valid file"
103+
f"Invalid: The entry `{entry}` in file `{file}` is NOT valid"
104104
f" according to the `{nxdl}` application definition.",
105105
)
106106

src/pynxtools/dataconverter/validation.py

Lines changed: 81 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,7 @@ def validate_hdf_group_against(
200200
def best_namefit_of(
201201
name: str,
202202
nodes: Iterable[NexusNode],
203+
hint: Optional[Literal["axis", "signal"]] = None,
203204
) -> Optional[NexusNode]:
204205
"""
205206
Get the best namefit of `name` in `nodes`.
@@ -218,6 +219,8 @@ def best_namefit_of(
218219
best_match = None
219220
best_score = -1
220221

222+
hint_map: dict[str, str] = {"DATA": "signal", "AXISNAME": "axis"}
223+
221224
for node in nodes:
222225
if not node.variadic:
223226
if name == node.name:
@@ -227,22 +230,28 @@ def best_namefit_of(
227230
name_partial = node.name_type == "partial"
228231
score = get_nx_namefit(name, node.name, name_any, name_partial)
229232
if score > best_score:
233+
if hint and hint_map.get(node.name) != hint:
234+
continue
230235
best_match = node
231236
best_score = score
232237

238+
# if hint:
239+
# print(hint, node, name)
240+
233241
return best_match
234242

235243
# Only cache based on path. That way we retain the nx_class information
236244
# in the tree
237245
# Allow for 10000 cache entries. This should be enough for most cases
238246
@cached(
239247
cache=LRUCache(maxsize=10000),
240-
key=lambda path, node_type=None, nx_class=None: hashkey(path),
248+
key=lambda path, node_type=None, nx_class=None, hint=None: hashkey(path),
241249
)
242250
def find_node_for(
243251
path: str,
244252
node_type: Optional[Literal["group", "field", "attribute"]] = None,
245253
nx_class: Optional[str] = None,
254+
hint: Optional[Literal["axis", "signal"]] = None,
246255
) -> Optional[NexusNode]:
247256
"""
248257
Find the NexusNode for a given HDF5 path, optionally constrained by node type and NX_class.
@@ -264,7 +273,7 @@ def find_node_for(
264273

265274
*prev_path, last_elem = path.rsplit("/", 1)
266275

267-
node = find_node_for(prev_path[0]) if prev_path else tree
276+
node = find_node_for(prev_path[0], hint=hint) if prev_path else tree
268277
current = copy.copy(node)
269278

270279
if node is None:
@@ -276,7 +285,7 @@ def find_node_for(
276285
nx_class=nx_class, node_type=node_type
277286
)
278287
]
279-
node = best_namefit_of(last_elem, children_to_check)
288+
node = best_namefit_of(last_elem, children_to_check, hint)
280289

281290
if node is None:
282291
# Check that there is no other node with the same name, but a different type
@@ -563,21 +572,25 @@ def check_reserved_prefix(
563572

564573
return
565574

566-
def handle_group(path: str, data: h5py.Group):
575+
def handle_group(path: str, group: h5py.Group):
567576
"""
568577
Handle validation logic for HDF5 groups.
569578
570579
Args:
571580
path (str): Relative HDF5 path to the group.
572-
data (h5py.Group): The group object.
581+
group (h5py.Group): The group object.
573582
"""
574583
full_path = f"{entry_name}/{path}"
575584

576585
check_reserved_prefix(full_path, appdef_node.name, "group")
577586

587+
if not group.attrs.get("NX_class"):
588+
# We ignore additional groups that don't have an NX_class
589+
return
590+
578591
try:
579592
node = find_node_for(
580-
path, node_type="group", nx_class=data.attrs.get("NX_class")
593+
path, node_type="group", nx_class=group.attrs.get("NX_class")
581594
)
582595
except TypeError:
583596
return
@@ -596,22 +609,22 @@ def handle_group(path: str, data: h5py.Group):
596609
return
597610

598611
if node.nx_class == "NXdata":
599-
handle_nxdata(path, data)
612+
handle_nxdata(path, group)
600613
if node.nx_class == "NXcollection":
601614
return
602615

603-
def handle_nxdata(path: str, data: h5py.Group):
616+
def handle_nxdata(path: str, group: h5py.Group):
604617
"""
605618
Handle validation of NXdata groups, including signal, axes, and auxiliary signals.
606619
607620
Args:
608621
path (str): HDF5 path to the NXdata group.
609-
data (h5py.Group): The NXdata group object.
622+
group (h5py.Group): The NXdata group object.
610623
"""
611624
full_path = f"{entry_name}/{path}"
612625

613626
def check_nxdata():
614-
data_field = data.get(signal)
627+
data_field = group.get(signal)
615628

616629
if data_field is None:
617630
collector.collect_and_log(
@@ -620,23 +633,20 @@ def check_nxdata():
620633
None,
621634
)
622635
else:
623-
handle_field(
624-
f"{path}/{signal}",
625-
data_field,
626-
)
636+
handle_field(f"{path}/{signal}", data_field, hint="signal")
627637

628638
# check NXdata attributes
629639
attrs = ("signal", "auxiliary_signals", "axes")
630-
data_attrs = {k: data.attrs[k] for k in attrs if k in data.attrs}
640+
data_attrs = {k: group.attrs[k] for k in attrs if k in group.attrs}
631641

632642
handle_attributes(path, data_attrs)
633643

634644
for i, axis in enumerate(axes):
635645
if axis == ".":
636646
continue
637-
index = data.get(f"{axis}_indices", i)
647+
index = group.get(f"{axis}_indices", i)
638648

639-
axis_field = data.get(axis)
649+
axis_field = group.get(axis)
640650

641651
if axis_field is None:
642652
collector.collect_and_log(
@@ -646,10 +656,7 @@ def check_nxdata():
646656
)
647657
break
648658
else:
649-
handle_field(
650-
f"{path}/{axis}",
651-
data_field,
652-
)
659+
handle_field(f"{path}/{axis}", axis_field, hint="axis")
653660
if np.shape(data_field)[index] != len(axis_field):
654661
collector.collect_and_log(
655662
f"{path}/{axis}",
@@ -658,45 +665,57 @@ def check_nxdata():
658665
index,
659666
)
660667

661-
signal = data.attrs.get("signal")
662-
aux_signals = data.attrs.get("auxiliary_signals", [])
663-
axes = data.attrs.get("axes", [])
668+
signal = group.attrs.get("signal")
669+
aux_signals = group.attrs.get("auxiliary_signals", [])
670+
axes = group.attrs.get("axes", [])
664671

665672
if isinstance(axes, str):
666673
axes = [axes]
667674

668-
if signal is not None:
669-
check_nxdata()
670-
671675
indices = map(lambda x: f"{x}_indices", axes)
672676
errors = map(lambda x: f"{x}_errors", [signal, *aux_signals, *axes])
673677

674-
def handle_field(path: str, data: h5py.Dataset):
678+
# TODO: check that the indices match
679+
# TODO: check that the errors have the same dim as the fields
680+
681+
if signal is not None:
682+
check_nxdata()
683+
684+
def handle_field(
685+
path: str,
686+
dataset: h5py.Dataset,
687+
hint: Optional[Literal["axis", "signal"]] = None,
688+
):
675689
"""
676690
Validate a NeXus field (dataset) within the HDF5 structure.
677691
678692
Args:
679693
path (str): Path to the dataset.
680694
data (h5py.Dataset): Dataset object.
695+
hint (str):
696+
If the field is in an NXdata group, this is used to figure out
697+
if it is an AXISNAME or a DATA.
681698
"""
699+
682700
full_path = f"{entry_name}/{path}"
683701
check_reserved_prefix(full_path, appdef_node.name, "field")
684702
try:
685-
node = find_node_for(path, node_type="field")
703+
node = find_node_for(path, node_type="field", hint=hint)
686704
except TypeError:
687705
return
706+
688707
if node is None:
689708
key_path = path.replace("@", "")
709+
parent_node = None
690710
while "/" in key_path:
691711
key_path = key_path.rsplit("/", 1)[0] # Remove last segment
692-
parent_node = find_node_for(path, node_type="field")
693-
if parent_node is None:
694-
parent_node = find_node_for(key_path, node_type="group")
695-
if (
696-
parent_node
697-
and parent_node.type == "group"
698-
and parent_node.nx_class == "NXcollection"
699-
):
712+
parent_data = data.get(key_path)
713+
nx_class = (
714+
parent_data.attrs.get("NX_class")
715+
if parent_data is not None
716+
else None
717+
)
718+
if nx_class == "NXcollection":
700719
# Collection found for parents, mark as documented
701720
return
702721

@@ -714,10 +733,14 @@ def handle_field(path: str, data: h5py.Dataset):
714733
return
715734

716735
is_valid_data_field(
717-
clean_str_attr(data[()]), node.dtype, node.items, node.open_enum, full_path
736+
clean_str_attr(dataset[()]),
737+
node.dtype,
738+
node.items,
739+
node.open_enum,
740+
full_path,
718741
)
719742

720-
units = data.attrs.get("units")
743+
units = dataset.attrs.get("units")
721744
units_path = f"{full_path}/@units"
722745
if node.unit is not None:
723746
remove_from_req_entities(f"{path}/@units")
@@ -730,7 +753,7 @@ def handle_field(path: str, data: h5py.Dataset):
730753
return
731754
# Special case: NX_TRANSFORMATION unit depends on `@transformation_type` attribute
732755
if (
733-
transformation_type := data.attrs.get("transformation_type")
756+
transformation_type := dataset.attrs.get("transformation_type")
734757
) is not None:
735758
hints = {"transformation_type": transformation_type}
736759
else:
@@ -766,26 +789,26 @@ def handle_attributes(path: str, attrs: h5py.AttributeManager):
766789
node = find_node_for(f"{path}/{attr_name}", node_type="attribute")
767790
except TypeError:
768791
return
769-
if node is None:
770-
key_path = full_path.replace("@", "")
771-
found_collection = False
772-
while "/" in key_path:
773-
key_path = key_path.rsplit("/", 1)[0] # Remove last segment
774-
parent_node = find_node_for(path, node_type="field")
775-
if parent_node is None:
776-
parent_node = find_node_for(key_path, node_type="group")
777792

778-
if (
779-
parent_node
780-
and parent_node.type == "group"
781-
and parent_node.nx_class == "NXcollection"
782-
):
783-
# Collection found for parents, mark as documented
784-
found_collection = True
785-
break
793+
key_path = f"{path}/{attr_name}"
794+
parent_node = None
795+
found_collection = False
796+
while "/" in key_path:
797+
key_path = key_path.rsplit("/", 1)[0] # Remove last segment
798+
parent_data = data.get(key_path)
799+
nx_class = (
800+
parent_data.attrs.get("NX_class")
801+
if parent_data is not None
802+
else None
803+
)
804+
if nx_class == "NXcollection":
805+
# Collection found for parents, mark as documented
806+
found_collection = True
807+
break
808+
if found_collection:
809+
continue # This continues the outer attr_name loop
786810

787-
if found_collection:
788-
continue # This continues the outer attr_name loop
811+
if node is None:
789812
if not ignore_undocumented:
790813
collector.collect_and_log(
791814
full_path,

src/pynxtools/testing/nexus_conversion.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,7 @@ def convert_to_nexus(
164164

165165
assert test_output == []
166166

167-
# Validate created file using the verify_nexus functionality
167+
# Validate created file using the validate_nexus functionality
168168
validate(self.created_nexus, ignore_undocumented=ignore_undocumented)
169169

170170
if NOMAD_AVAILABLE:

0 commit comments

Comments (0)