1111import re
1212import typing
1313from dataclasses import dataclass
14- from functools import partial
14+ from functools import cached_property , partial
1515from itertools import product
1616from operator import itemgetter
1717from xml .etree import ElementTree as xmlET
@@ -855,7 +855,16 @@ def _maybe_get(
855855 return maybe_info
856856
857857
858- def read_geom_h5 (xdmf_file : Path , snapshot : int ) -> Tuple [Dict [str , Any ], Element ]:
858+ @dataclass (frozen = True )
859+ class FieldXmf :
860+ path : Path
861+
862+ @cached_property
863+ def root (self ) -> Element :
864+ return xmlET .parse (str (self .path )).getroot ()
865+
866+
867+ def read_geom_h5 (xdmf : FieldXmf , snapshot : int ) -> Tuple [Dict [str , Any ], Element ]:
859868 """Extract geometry information from hdf5 files.
860869
861870 Args:
@@ -865,15 +874,15 @@ def read_geom_h5(xdmf_file: Path, snapshot: int) -> Tuple[Dict[str, Any], Elemen
865874 geometry information and root of xdmf document.
866875 """
867876 header : Dict [str , Any ] = {}
868- xdmf_root = xmlET . parse ( str ( xdmf_file )). getroot ()
877+ xdmf_root = xdmf . root
869878 if snapshot is None :
870879 return {}, xdmf_root
871880
872881 # Domain, Temporal Collection, Snapshot
873882 # should check that this is indeed the required snapshot
874883 elt_snap = xdmf_root [0 ][0 ][snapshot ]
875884 if elt_snap is None :
876- raise ParsingError (xdmf_file , f"Snapshot { snapshot } not present" )
885+ raise ParsingError (xdmf . path , f"Snapshot { snapshot } not present" )
877886 header ["ti_ad" ] = _maybe_get (elt_snap , "Time" , "Value" , float )
878887 header ["mo_lambda" ] = _maybe_get (elt_snap , "mo_lambda" , "Value" , float )
879888 header ["mo_thick_sol" ] = _maybe_get (elt_snap , "mo_thick_sol" , "Value" , float )
@@ -882,21 +891,21 @@ def read_geom_h5(xdmf_file: Path, snapshot: int) -> Tuple[Dict[str, Any], Elemen
882891 coord_shape = [] # shape of meshes
883892 twod = None
884893 for elt_subdomain in elt_snap .findall ("Grid" ):
885- elt_name = _try_get (xdmf_file , elt_subdomain , "Name" )
894+ elt_name = _try_get (xdmf . path , elt_subdomain , "Name" )
886895 if elt_name .startswith ("meshYang" ):
887896 header ["ntb" ] = 2
888897 break # iterate only through meshYin
889- elt_geom = _try_find (xdmf_file , elt_subdomain , "Geometry" )
898+ elt_geom = _try_find (xdmf . path , elt_subdomain , "Geometry" )
890899 if elt_geom .get ("Type" ) == "X_Y" and twod is None :
891900 twod = ""
892901 for data_item in elt_geom .findall ("DataItem" ):
893- coord = _try_text (xdmf_file , data_item ).strip ()[- 1 ]
902+ coord = _try_text (xdmf . path , data_item ).strip ()[- 1 ]
894903 if coord in "XYZ" :
895904 twod += coord
896- data_item = _try_find (xdmf_file , elt_geom , "DataItem" )
897- data_text = _try_text (xdmf_file , data_item )
898- coord_shape .append (_get_dim (xdmf_file , data_item ))
899- coord_h5 .append (xdmf_file .parent / data_text .strip ().split (":/" , 1 )[0 ])
905+ data_item = _try_find (xdmf . path , elt_geom , "DataItem" )
906+ data_text = _try_text (xdmf . path , data_item )
907+ coord_shape .append (_get_dim (xdmf . path , data_item ))
908+ coord_h5 .append (xdmf . path .parent / data_text .strip ().split (":/" , 1 )[0 ])
900909 _read_coord_h5 (coord_h5 , coord_shape , header , twod )
901910 return header , xdmf_root
902911
@@ -952,7 +961,7 @@ def _post_read_flds(flds: ndarray, header: Dict[str, Any]) -> ndarray:
952961
953962
954963def read_field_h5 (
955- xdmf_file : Path ,
964+ xdmf : FieldXmf ,
956965 fieldname : str ,
957966 snapshot : int ,
958967 header : Optional [Dict [str , Any ]] = None ,
@@ -969,22 +978,22 @@ def read_field_h5(
969978 unavailable.
970979 """
971980 if header is None :
972- header , xdmf_root = read_geom_h5 (xdmf_file , snapshot )
981+ header , xdmf_root = read_geom_h5 (xdmf , snapshot )
973982 else :
974- xdmf_root = xmlET . parse ( str ( xdmf_file )). getroot ()
983+ xdmf_root = xdmf . root
975984
976985 npc = header ["nts" ] // header ["ncs" ] # number of grid point per node
977986 flds = np .zeros (_flds_shape (fieldname , header ))
978987 data_found = False
979988
980989 for elt_subdomain in xdmf_root [0 ][0 ][snapshot ].findall ("Grid" ):
981- elt_name = _try_get (xdmf_file , elt_subdomain , "Name" )
990+ elt_name = _try_get (xdmf . path , elt_subdomain , "Name" )
982991 ibk = int (elt_name .startswith ("meshYang" ))
983992 for data_attr in elt_subdomain .findall ("Attribute" ):
984993 if data_attr .get ("Name" ) != fieldname :
985994 continue
986995 icore , fld = _get_field (
987- xdmf_file , _try_find (xdmf_file , data_attr , "DataItem" )
996+ xdmf . path , _try_find (xdmf . path , data_attr , "DataItem" )
988997 )
989998 # for some reason, the field is transposed
990999 fld = fld .T
0 commit comments