Skip to content
Open
6 changes: 3 additions & 3 deletions .github/workflows/plugin_test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ jobs:
branch: main
tests_to_run: tests/.
- plugin: pynxtools-spm
branch: main
branch: RepalceNexusFile
tests_to_run: tests/.
- plugin: pynxtools-xps
branch: main
branch: adapt-testing
tests_to_run: tests/.
- plugin: pynxtools-xrd
branch: main
branch: RepalceNexusFile
tests_to_run: tests/.

steps:
Expand Down
95 changes: 90 additions & 5 deletions src/pynxtools/nexus/nexus.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,6 +425,91 @@ def get_inherited_hdf_nodes(
return (class_path, nxdl_elem_path, elist)


def safe_str(value, precision: int = 8) -> str:
"""Return a deterministic string representation of arrays, lists, or scalars.

Floats are formatted consistently across systems to ensure deterministic
output. Special handling is applied to simplify representation:
- `0.0` → `'0.0'`
- `1.0` → `'1'`
- `1.50` → `'1.5'`
- Non-integer floats keep up to `precision` decimals with trailing zeros
and dots removed.

Arrays and lists are formatted elementwise using the same rules.

Args:
value: The input value to format. Can be a scalar, list, tuple,
NumPy array, or basic type such as int, float, str, or bytes.
precision (int): Maximum number of decimal places for non-integer
floats. Defaults to 8.

Returns:
str: Deterministic string representation of the input.
"""
# Normalize NumPy scalar and 0D array types
if isinstance(value, np.generic):
value = value.item()
elif isinstance(value, np.ndarray) and value.shape == ():
value = value.item()

def format_float(value: float) -> str:
"""Format a float deterministically."""
if value == 0.0:
return "0.0"
if value.is_integer():
return str(int(value))
if abs(value) < 10**-precision or abs(value) >= 10 ** (precision + 1):
return f"{value:.{precision}e}"
return f"{value:.{precision}f}".rstrip("0").rstrip(".")

# --- Arrays ---
if isinstance(value, np.ndarray):
flat = value.flatten()
formatted = []
for v in flat:
if isinstance(v, (np.generic, np.ndarray)):
v = v.item()
if isinstance(v, float):
formatted.append(format_float(v))
elif isinstance(v, (int, bool)):
formatted.append(str(v))
elif isinstance(v, str):
formatted.append(v)
elif isinstance(v, bytes):
formatted.append(v.decode(errors="replace"))
else:
formatted.append(str(v))
reshaped = np.array(formatted, dtype=object).reshape(value.shape)
return np.array2string(
reshaped,
separator=", ",
formatter={"all": lambda x: str(x)},
max_line_width=1000000,
threshold=6,
)

# --- Lists / tuples ---
if isinstance(value, list | tuple):
formatted = [safe_str(v, precision) for v in value]
return "[" + ", ".join(formatted) + "]"

# --- Floats ---
if isinstance(value, float | np.floating):
return format_float(float(value))

# --- Integers / booleans ---
elif isinstance(value, (int, np.integer, bool, np.bool_)):
return str(value)

# --- Strings / bytes ---
elif isinstance(value, (bytes, str)):
return value if isinstance(value, str) else value.decode(errors="replace")

# --- Fallback ---
return str(value)


def process_node(hdf_node, hdf_path, parser, logger, doc=True):
"""Processes an hdf5 node.
- it logs the node found and also checks for its attributes
Expand All @@ -436,11 +521,11 @@ def process_node(hdf_node, hdf_path, parser, logger, doc=True):
if isinstance(hdf_node, h5py.Dataset):
logger.debug(f"===== FIELD (/{hdf_path}): {hdf_node}")
val = (
str(decode_if_string(hdf_node[()])).split("\n")
safe_str(decode_if_string(hdf_node[()])).split("\n")
if len(hdf_node.shape) <= 1
else str(decode_if_string(hdf_node[0])).split("\n")
else safe_str(decode_if_string(hdf_node[0])).split("\n")
)
logger.debug(f"value: {val[0]} {'...' if len(val) > 1 else ''}")
logger.debug(f"value: {val[0]}{' ...' if len(val) > 1 else ''}")
else:
logger.debug(
f"===== GROUP (/{hdf_path} "
Expand All @@ -460,8 +545,8 @@ def process_node(hdf_node, hdf_path, parser, logger, doc=True):
)
for key, value in hdf_node.attrs.items():
logger.debug(f"===== ATTRS (/{hdf_path}@{key})")
val = str(decode_if_string(value)).split("\n")
logger.debug(f"value: {val[0]} {'...' if len(val) > 1 else ''}")
val = safe_str(decode_if_string(value)).split("\n")
logger.debug(f"value: {val[0]}{' ...' if len(val) > 1 else ''}")
(req_str, nxdef, nxdl_path) = get_nxdl_doc(hdf_info, logger, doc, attr=key)
if (
parser is not None
Expand Down
2 changes: 1 addition & 1 deletion src/pynxtools/nomad/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ def _get_value(hdf_node):
def decode_array(arr):
result = []
for x in arr:
if isinstance(x, (np.ndarray, list)):
if isinstance(x, np.ndarray | list):
result.append(decode_array(x))
else:
result.append(str(decode_or_not(x)))
Expand Down
46 changes: 32 additions & 14 deletions src/pynxtools/testing/nexus_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,14 @@ class ReaderTest:
"""Generic test for reader plugins."""

def __init__(
self, nxdl, reader_name, files_or_dir, tmp_path, caplog, **kwargs
self,
nxdl,
reader_name,
files_or_dir,
tmp_path,
caplog,
ref_log_path=None,
**kwargs,
) -> None:
"""Initialize the test object.

Expand All @@ -75,13 +82,15 @@ def __init__(
files_or_dir : str
List of input files or full path string to the example data directory that contains all the files
required for running the data conversion through the reader.
ref_nexus_file : str
Full path string to the reference NeXus file generated from the same
set of input files.
tmp_path : pathlib.PosixPath
Pytest fixture variable, used to clean up the files generated during the test.
caplog : _pytest.logging.LogCaptureFixture
Pytest fixture variable, used to capture the log messages during the test.
ref_log_path : str
Full path string to the reference log file generated from the same
set of input files in files_or_dir. This can also be parsed automatically if
files_or_dir is the full path string to the example data directory and there
is only one reference log file.
kwargs : dict[str, Any]
Any additional keyword arguments to be passed to the readers' read function.
"""
Expand All @@ -91,11 +100,11 @@ def __init__(
self.reader = get_reader(self.reader_name)

self.files_or_dir = files_or_dir
self.ref_nexus_file = ""
self.tmp_path = tmp_path
self.caplog = caplog
self.created_nexus = f"{tmp_path}/{os.sep}/output.nxs"
self.ref_log_path = ref_log_path
self.kwargs = kwargs
self.created_nexus = f"{tmp_path}/{os.sep}/output.nxs"

def convert_to_nexus(
self,
Expand All @@ -114,15 +123,18 @@ def convert_to_nexus(
example_files = self.files_or_dir
else:
example_files = sorted(glob(os.path.join(self.files_or_dir, "*")))
self.ref_nexus_file = [file for file in example_files if file.endswith(".nxs")][
0
]

if not self.ref_log_path:
self.ref_log_path = next(
(file for file in example_files if file.endswith(".log")), None
)
assert self.ref_log_path, "Reference nexus .log file not found"

input_files = [
file
for file in example_files
if not file.endswith((".nxs", "ref_output.txt"))
if not file.endswith((".nxs", "ref_output.txt", ".log"))
]
assert self.ref_nexus_file, "Reference nexus (.nxs) file not found"

assert (
self.nxdl in self.reader.supported_nxdls
Expand Down Expand Up @@ -167,10 +179,17 @@ def convert_to_nexus(

assert test_output == []

# Validate created file using the validate_nexus functionality
def validate_nexus_file(
    self,
    caplog_level: Literal["ERROR", "WARNING"] = "ERROR",
    ignore_undocumented: bool = False,
) -> None:
    """Validate the created NeXus using the validate_nexus functionality.

    Args:
        caplog_level: Log level at which messages are captured by the
            pytest ``caplog`` fixture while the validator runs.
        ignore_undocumented: Forwarded to ``validate``; presumably
            suppresses reporting of undocumented entries — confirm
            against the ``validate`` implementation.
    """
    # Capture validator log output so the surrounding test can inspect
    # it afterwards via self.caplog.
    with self.caplog.at_level(caplog_level):
        validate(self.created_nexus, ignore_undocumented=ignore_undocumented)

def parse_nomad(self):
"""Test if the created NeXus file can be parsed by NOMAD."""
if NOMAD_AVAILABLE:
kwargs = dict(
strict=True,
Expand Down Expand Up @@ -304,9 +323,8 @@ def extra_lines(lines1: list[str], lines2: list[str]) -> list[str | None]:
raise AssertionError("\n".join(diffs))

# Load log paths
ref_log_path = get_log_file(self.ref_nexus_file, "ref_nexus.log", self.tmp_path)
gen_log_path = get_log_file(self.created_nexus, "gen_nexus.log", self.tmp_path)
gen_lines, ref_lines = load_logs(gen_log_path, ref_log_path)
gen_lines, ref_lines = load_logs(gen_log_path, self.ref_log_path)

# Compare logs
compare_logs(gen_lines, ref_lines)
Loading
Loading