|
20 | 20 | from collections import defaultdict |
21 | 21 | from functools import reduce |
22 | 22 | from operator import getitem |
23 | | -from typing import Any, Iterable, List, Mapping, Optional, Tuple, Union |
| 23 | +from typing import Any, Iterable, List, Mapping, Optional, Set, Tuple, Union |
24 | 24 |
|
25 | 25 | import h5py |
26 | 26 | import lxml.etree as ET |
27 | 27 | import numpy as np |
28 | 28 | from anytree import Resolver |
| 29 | +from cachetools import LRUCache, cached |
| 30 | +from cachetools.keys import hashkey |
29 | 31 |
|
30 | 32 | from pynxtools.dataconverter.helpers import ( |
31 | 33 | Collector, |
|
42 | 44 | from pynxtools.definitions.dev_tools.utils.nxdl_utils import get_nx_namefit |
43 | 45 |
|
44 | 46 |
|
def best_namefit_of_(
    name: str, concepts: Set[str], nx_class: Optional[str] = None
) -> Optional[str]:
    """
    Return the concept name from `concepts` that best fits `name`.

    Args:
        name: The concrete name found in the HDF5 file.
        concepts: The candidate concept names from the application definition.
        nx_class: Optional NX_class of the group, used to narrow candidates.

    Returns:
        The best-fitting concept name, or None if nothing fits.
        (The caller treats None as "undocumented", so the return type
        must be Optional[str], not str.)
    """
    # An exact match always wins.
    if name in concepts:
        return name

    # TODO: Score the remaining candidates with get_nx_namefit and
    # consider nx_class if it is not None.
    return None
def validate_hdf_group_against(appdef: str, data: h5py.Group):
    """
    Check an HDF5 group against the application definition `appdef`.

    Walks all groups, datasets, and attributes in `data`, resolves each
    path to a node in the tree generated from `appdef`, and logs any
    required fields/attributes that are missing at the end.

    THIS IS JUST A FUNCTION SKELETON AND IS NOT WORKING YET!
    """

    # Only cache based on path. That way we retain the nx_class information
    # in the tree.
    # Allow for 10000 cache entries. This should be enough for most cases.
    # NOTE: nx_class must be optional in the key function too — find_node_for
    # is also called recursively with a single argument, and cachetools
    # invokes key(*args), which would otherwise raise TypeError.
    @cached(
        cache=LRUCache(maxsize=10000),
        key=lambda path, nx_class=None: hashkey(path),
    )
    def find_node_for(path: str, nx_class: Optional[str] = None) -> Optional[NexusNode]:
        if path == "":
            return tree

        # Paths from visititems are relative; top-level items contain no "/",
        # for which rsplit would yield only one element.
        if "/" in path:
            prev_path, last_elem = path.rsplit("/", 1)
        else:
            prev_path, last_elem = "", path

        node = find_node_for(prev_path)
        if node is None:
            # Parent is already undocumented; so is everything below it.
            return None

        best_child = best_namefit_of_(
            last_elem,
            # TODO: Consider renaming `get_all_children_names` to
            # `get_all_direct_children_names`. Because that's what it is.
            node.get_all_children_names(),
            nx_class,
        )
        if best_child is None:
            return None

        return node.search_child_with_name(best_child)

    def remove_from_req_fields(path: str):
        # Mark a required field/attribute as present.
        if path in required_fields:
            required_fields.remove(path)

    def handle_group(path: str, data: h5py.Group):
        node = find_node_for(path, data.attrs.get("NX_class"))
        if node is None:
            # TODO: Log undocumented
            return

        # TODO: Do actual group checks

    def handle_field(path: str, data: h5py.Dataset):
        node = find_node_for(path)
        if node is None:
            # TODO: Log undocumented
            return
        remove_from_req_fields(f"{path}")

        # TODO: Do actual field checks

    def handle_attributes(path: str, attribute_names: h5py.AttributeManager):
        for attr_name in attribute_names:
            # NOTE(review): the tree lookup uses "path/attr" while the
            # required-fields bookkeeping uses "path/@attr" — confirm the
            # tree really resolves attributes without the "@" prefix.
            node = find_node_for(f"{path}/{attr_name}")
            if node is None:
                # TODO: Log undocumented
                continue
            remove_from_req_fields(f"{path}/@{attr_name}")

            # TODO: Do actual attribute checks

    def validate(path: str, data: Union[h5py.Group, h5py.Dataset]):
        # Namefit name against tree (use recursive caching)
        if isinstance(data, h5py.Group):
            handle_group(path, data)
        elif isinstance(data, h5py.Dataset):
            handle_field(path, data)

        handle_attributes(path, data.attrs)

    tree = generate_tree_from(appdef)
    required_fields = tree.required_fields_and_attrs_names()
    # h5py's recursive walk is `visititems` (not `visitems`).
    data.visititems(validate)

    # Anything still in required_fields was never seen in the file.
    for req_field in required_fields:
        if "@" in req_field:
            collector.collect_and_log(
                req_field, ValidationProblem.MissingRequiredAttribute, None
            )
            continue
        collector.collect_and_log(
            req_field, ValidationProblem.MissingRequiredField, None
        )
59 | 142 |
|
60 | 143 | def build_nested_dict_from( |
61 | 144 | mapping: Mapping[str, Any], |
|
0 commit comments