Skip to content
6 changes: 6 additions & 0 deletions src/pynxtools/data/NXtest.nxdl.xml
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,12 @@
<item value="1st type open"/>
<item value="2nd type open"/>
</enumeration>
<attribute name="attribute_with_open_enum" optional="true">
<enumeration open="true">
<item value="1st option"/>
<item value="2nd option"/>
</enumeration>
</attribute>
</field>
<attribute name="group_attribute">
</attribute>
Expand Down
145 changes: 105 additions & 40 deletions src/pynxtools/dataconverter/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import logging
import os
import re
from collections.abc import Mapping, Sequence
from collections.abc import Mapping, MutableMapping, Sequence
from datetime import datetime, timezone
from enum import Enum, auto
from functools import cache, lru_cache
Expand Down Expand Up @@ -87,7 +87,9 @@ class ValidationProblem(Enum):
UnitWithoutDocumentation = auto()
InvalidUnit = auto()
InvalidEnum = auto()
OpenEnumWithNewItem = auto()
OpenEnumWithCustom = auto()
OpenEnumWithCustomFalse = auto()
OpenEnumWithMissingCustom = auto()
MissingRequiredGroup = auto()
MissingRequiredField = auto()
MissingRequiredAttribute = auto()
Expand Down Expand Up @@ -152,12 +154,27 @@ def _log(self, path: str, log_type: ValidationProblem, value: Optional[Any], *ar

elif log_type == ValidationProblem.InvalidEnum:
logger.warning(
f"The value at {path} should be one of the following: {value}."
f"The value '{args[0]}' at {path} should be one of the following: {value}."
)
elif log_type == ValidationProblem.OpenEnumWithNewItem:
elif log_type == ValidationProblem.OpenEnumWithCustom:
logger.info(
f"The value at {path} does not match with the enumerated items from the open enumeration: {value}."
f"The value '{args[0]}' at {path} does not match with the enumerated items from the open enumeration: {value}."
)
elif log_type == ValidationProblem.OpenEnumWithCustomFalse:
logger.warning(
f"The value '{args[0]}' at {path} does not match with the enumerated items from the open enumeration: {value}. "
"When a different value is used, the boolean 'custom' attribute cannot be False."
)
elif log_type == ValidationProblem.OpenEnumWithMissingCustom:
log_text = (
f"The value '{args[0]}' at {path} does not match with the enumerated items from the open enumeration: {value}. "
"When a different value is used, a boolean 'custom=True' attribute must be added."
)
if args[1] is True:
log_text += " It was added here automatically."
logger.info(log_text)
else:
logger.warning(log_text)
elif log_type == ValidationProblem.MissingRequiredGroup:
logger.warning(f"The required group {path} hasn't been supplied.")
elif log_type == ValidationProblem.MissingRequiredField:
Expand Down Expand Up @@ -287,9 +304,10 @@ def collect_and_log(
# info messages should not fail validation
if log_type in (
ValidationProblem.UnitWithoutDocumentation,
ValidationProblem.OpenEnumWithNewItem,
ValidationProblem.CompressionStrengthZero,
ValidationProblem.MissingNXclass,
ValidationProblem.OpenEnumWithCustom,
ValidationProblem.OpenEnumWithMissingCustom,
):
if self.logging and message not in self.data["info"]:
self._log(path, log_type, value, *args, **kwargs)
Expand Down Expand Up @@ -804,15 +822,11 @@ def convert_int_to_float(value):
return value


def is_valid_data_field(
value: Any, nxdl_type: str, nxdl_enum: list, nxdl_enum_open: bool, path: str
) -> Any:
def is_valid_data_field(value: Any, nxdl_type: str, path: str) -> Any:
"""Checks whether a given value is valid according to the type defined in the NXDL."""

def validate_data_value(
value: Any, nxdl_type: str, nxdl_enum: list, nxdl_enum_open: bool, path: str
) -> Any:
"""Validate and possibly convert a primitive value according to NXDL type/enum rules."""
def validate_data_value(value: Any, nxdl_type: str, path: str) -> Any:
"""Validate and possibly convert a primitive value according to NXDL type rules."""
accepted_types = NEXUS_TO_PYTHON_DATA_TYPES[nxdl_type]
original_value = value

Expand Down Expand Up @@ -843,26 +857,6 @@ def validate_data_value(
path, ValidationProblem.InvalidDatetime, value
)

if nxdl_enum is not None:
if (
isinstance(value, np.ndarray)
and isinstance(nxdl_enum, list)
and isinstance(nxdl_enum[0], list)
):
enum_value = list(value)
else:
enum_value = value

if enum_value not in nxdl_enum:
if nxdl_enum_open:
collector.collect_and_log(
path, ValidationProblem.OpenEnumWithNewItem, nxdl_enum
)
else:
collector.collect_and_log(
path, ValidationProblem.InvalidEnum, nxdl_enum
)

return value

if isinstance(value, dict) and set(value.keys()) == {"compress", "strength"}:
Expand All @@ -878,18 +872,89 @@ def validate_data_value(
path, ValidationProblem.InvalidCompressionStrength, value
)
# In this case, we remove the compression.
return validate_data_value(
value["compress"], nxdl_type, nxdl_enum, nxdl_enum_open, path
)
return validate_data_value(value["compress"], nxdl_type, path)

# Apply standard validation to compressed value
value["compress"] = validate_data_value(
compressed_value, nxdl_type, nxdl_enum, nxdl_enum_open, path
)
value["compress"] = validate_data_value(compressed_value, nxdl_type, path)

return value

return validate_data_value(value, nxdl_type, nxdl_enum, nxdl_enum_open, path)
return validate_data_value(value, nxdl_type, path)


def get_custom_attr_path(path: str) -> str:
    """Return the path of the companion 'custom' marker for *path*.

    For an attribute path (last segment starts with "@") the marker is a
    sibling attribute named "<attr>_custom"; for a field/group path it is
    a child attribute "@custom".

    Args:
        path: Slash-separated NeXus path of a field or attribute.

    Returns:
        The path under which the boolean 'custom' marker is stored.
    """
    # Fix: the original computed an `attr_name` local that was never used.
    if path.split("/")[-1].startswith("@"):
        return f"{path}_custom"
    return f"{path}/@custom"


def is_valid_enum(
    value: Any,
    nxdl_enum: list,
    nxdl_enum_open: bool,
    path: str,
    mapping: MutableMapping,
):
    """Check a value against an NXDL enumeration (closed or open).

    For a closed enumeration, a non-member value is logged as InvalidEnum.
    For an open enumeration, a non-member value is allowed but must be
    accompanied by a boolean 'custom=True' attribute; depending on the
    state of that attribute the appropriate OpenEnum* problem is logged,
    and for writable mappings the attribute is added automatically.

    Args:
        value: The value to check (compression dicts are unwrapped first).
        nxdl_enum: Enumerated items from the NXDL, or None if no enumeration.
        nxdl_enum_open: Whether the enumeration is open (allows new items).
        path: Template path of the value, used for logging and for locating
            the companion 'custom' attribute.
        mapping: The template mapping or an h5py.Group (HDF5 validation).
    """
    # Unwrap a compression wrapper dict to get at the actual data.
    if isinstance(value, dict) and set(value.keys()) == {"compress", "strength"}:
        value = value["compress"]

    if nxdl_enum is None:
        return

    # numpy arrays are not directly comparable to list-valued enum items;
    # convert to a plain list so membership testing works.
    if (
        isinstance(value, np.ndarray)
        and isinstance(nxdl_enum, list)
        and isinstance(nxdl_enum[0], list)
    ):
        enum_value = list(value)
    else:
        enum_value = value

    if enum_value in nxdl_enum:
        return

    if not nxdl_enum_open:
        collector.collect_and_log(path, ValidationProblem.InvalidEnum, nxdl_enum, value)
        return

    custom_path = get_custom_attr_path(path)

    if isinstance(mapping, h5py.Group):
        # HDF5 validation: read the attribute from the file; it cannot be
        # added automatically there.
        parent_path, attr_name = custom_path.rsplit("@", 1)
        custom_attr = mapping.get(parent_path).attrs.get(attr_name)
        custom_added_auto = False
    else:
        custom_attr = mapping.get(custom_path)
        custom_added_auto = True

    # == (not `is`) on purpose: h5py/numpy may return numpy bools.
    if custom_attr == True:  # noqa: E712
        collector.collect_and_log(
            path,
            ValidationProblem.OpenEnumWithCustom,
            nxdl_enum,
            value,
        )
    elif custom_attr == False:  # noqa: E712
        collector.collect_and_log(
            path,
            ValidationProblem.OpenEnumWithCustomFalse,
            nxdl_enum,
            value,
        )
    elif custom_attr is None:
        try:
            mapping[custom_path] = True
        except ValueError:
            # We are in the HDF5 validation; the custom attribute cannot be
            # set. Fix: report that it was NOT added automatically instead
            # of falsely logging "It was added here automatically".
            custom_added_auto = False
        collector.collect_and_log(
            path,
            ValidationProblem.OpenEnumWithMissingCustom,
            nxdl_enum,
            value,
            custom_added_auto,
        )


def split_class_and_name_of(name: str) -> tuple[Optional[str], str]:
Expand Down
56 changes: 53 additions & 3 deletions src/pynxtools/dataconverter/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@
clean_str_attr,
collector,
convert_nexus_to_caps,
get_custom_attr_path,
is_valid_data_field,
is_valid_enum,
split_class_and_name_of,
)
from pynxtools.dataconverter.nexus_tree import (
Expand Down Expand Up @@ -644,9 +646,14 @@ def handle_field(
is_valid_data_field(
clean_str_attr(dataset[()]),
node.dtype,
full_path,
)
is_valid_enum(
clean_str_attr(dataset[()]),
node.items,
node.open_enum,
full_path,
data,
)

units = dataset.attrs.get("units")
Expand Down Expand Up @@ -695,7 +702,12 @@ def handle_attributes(
for attr_name in attrs:
full_path = f"{entry_name}/{path}/@{attr_name}"

if attr_name in ("NX_class", "units", "target"):
if attr_name in (
"NX_class",
"units",
"target",
"custom",
) or attr_name.endswith("_custom"):
# Ignore special attrs
continue

Expand Down Expand Up @@ -734,9 +746,15 @@ def handle_attributes(
is_valid_data_field(
attr_data,
node.dtype,
full_path,
)

is_valid_enum(
attr_data,
node.items,
node.open_enum,
full_path,
data,
)

def validate(path: str, h5_obj: Union[h5py.Group, h5py.Dataset]):
Expand Down Expand Up @@ -1342,10 +1360,16 @@ def handle_field(node: NexusNode, keys: Mapping[str, Any], prev_path: str):
mapping[variant_path] = is_valid_data_field(
keys[variant],
node.dtype,
variant_path,
)
is_valid_enum(
mapping[variant_path],
node.items,
node.open_enum,
variant_path,
mapping,
)
remove_from_not_visited(get_custom_attr_path(variant_path))

check_reserved_suffix(variant_path, keys)
check_reserved_prefix(variant_path, get_definition(variant_path), "field")
Expand Down Expand Up @@ -1410,10 +1434,16 @@ def handle_attribute(node: NexusNode, keys: Mapping[str, Any], prev_path: str):
f"{prev_path}/{variant if variant.startswith('@') else f'@{variant}'}"
],
node.dtype,
variant_path,
)
is_valid_enum(
mapping[variant_path],
node.items,
node.open_enum,
variant_path,
mapping,
)
remove_from_not_visited(get_custom_attr_path(variant_path))
check_reserved_prefix(
variant_path, get_definition(variant_path), "attribute"
)
Expand Down Expand Up @@ -1641,8 +1671,18 @@ def is_documented(key: str, tree: NexusNode) -> bool:
keys_to_remove.append(key)
return False
resolved_link[key] = is_valid_data_field(
resolved_link[key], node.dtype, node.items, node.open_enum, key
resolved_link[key],
node.dtype,
key,
)
is_valid_enum(
resolved_link[key],
node.items,
node.open_enum,
key,
mapping,
)
remove_from_not_visited(get_custom_attr_path(key))

return True

Expand All @@ -1656,8 +1696,18 @@ def is_documented(key: str, tree: NexusNode) -> bool:

# Check general validity
mapping[key] = is_valid_data_field(
mapping[key], node.dtype, node.items, node.open_enum, key
mapping[key],
node.dtype,
key,
)
is_valid_enum(
mapping[key],
node.items,
node.open_enum,
key,
mapping,
)
remove_from_not_visited(get_custom_attr_path(key))

# Check main field exists for units
if (
Expand Down
22 changes: 14 additions & 8 deletions src/pynxtools/testing/nexus_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,17 +241,24 @@ def load_logs(
def compare_logs(gen_lines: list[str], ref_lines: list[str]) -> None:
"""Compare log lines, ignoring specific differences."""

def get_section_ignore_lines(line: str) -> list[str]:
    """Look up the ignore-line rules for the section a separator line opens."""
    section_name = line.rsplit(SECTION_SEPARATOR, 1)[-1].strip()
    # First matching section prefix wins; no match means nothing to ignore.
    return next(
        (
            rules
            for prefix, rules in IGNORE_SECTIONS.items()
            if section_name.startswith(prefix)
        ),
        [],
    )

def extra_lines(
lines1: list[str], lines2: list[str]
) -> list[Optional[str]]:
"""Return lines in lines1 but not in lines2, with line numbers and ignoring specified lines."""
diffs: list[Optional[str]] = []
"""Return lines in lines1 but not in lines2 with line numbers."""
diffs = []
section_ignore_lines = []
section = None
for ind, line in enumerate(lines1):
if line.startswith(SECTION_SEPARATOR):
section = line.rsplit(SECTION_SEPARATOR)[-1].strip()
section_ignore_lines = IGNORE_SECTIONS.get(section, [])
section_ignore_lines = get_section_ignore_lines(line)
if line not in lines2 and not should_skip_line(
line, ignore_lines=IGNORE_LINES + section_ignore_lines
):
Expand Down Expand Up @@ -282,13 +289,12 @@ def extra_lines(
# Case 2: same line counts, check for diffs
diffs = []
section_ignore_lines = []
section = None

for ind, (gen_l, ref_l) in enumerate(zip(gen_lines, ref_lines)):
if gen_l.startswith(SECTION_SEPARATOR) and ref_l.startswith(
SECTION_SEPARATOR
):
section = gen_l.rsplit(SECTION_SEPARATOR)[-1].strip()
section_ignore_lines = IGNORE_SECTIONS.get(section, [])
section_ignore_lines = get_section_ignore_lines(gen_l)
if gen_l != ref_l and not should_skip_line(
gen_l, ref_l, ignore_lines=IGNORE_LINES + section_ignore_lines
):
Expand Down
Loading
Loading