Skip to content

Commit 864d219

Browse files
committed
inline _read_coord_h5
1 parent b7f9727 commit 864d219

File tree

1 file changed

+56
-70
lines changed

1 file changed

+56
-70
lines changed

stagpy/stagyyparsers.py

Lines changed: 56 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from xml.etree.ElementTree import Element
4141

4242
from numpy import ndarray
43+
from numpy.typing import NDArray
4344
from pandas import DataFrame
4445

4546
T = TypeVar("T")
@@ -718,75 +719,6 @@ def _conglomerate_meshes(
718719
return meshout
719720

720721

721-
def _read_coord_h5(
722-
files: List[Path],
723-
shapes: List[Tuple[int, ...]],
724-
header: Dict[str, Any],
725-
twod: Optional[str],
726-
) -> None:
727-
"""Read all coord hdf5 files of a snapshot.
728-
729-
Args:
730-
files: list of NodeCoordinates files of a snapshot.
731-
shapes: shape of mesh grids.
732-
header: geometry info.
733-
twod: 'XZ', 'YZ' or None depending on what is relevant.
734-
"""
735-
all_meshes: List[Dict[str, ndarray]] = []
736-
for h5file, shape in zip(files, shapes):
737-
all_meshes.append({})
738-
with h5py.File(h5file, "r") as h5f:
739-
for coord, mesh in h5f.items():
740-
# for some reason, the array is transposed!
741-
all_meshes[-1][coord] = mesh[()].reshape(shape).T
742-
all_meshes[-1][coord] = _make_3d(all_meshes[-1][coord], twod)
743-
744-
header["ncs"] = _ncores(all_meshes, twod)
745-
header["nts"] = list(
746-
(all_meshes[0]["X"].shape[i] - 1) * header["ncs"][i] for i in range(3)
747-
)
748-
header["nts"] = np.array([max(1, val) for val in header["nts"]])
749-
meshes = _conglomerate_meshes(all_meshes, header)
750-
if np.any(meshes["Z"][:, :, 0] != 0):
751-
# spherical
752-
if twod is not None: # annulus geometry...
753-
header["x_mesh"] = np.copy(meshes["Y"])
754-
header["y_mesh"] = np.copy(meshes["Z"])
755-
header["z_mesh"] = np.copy(meshes["X"])
756-
else: # YinYang, here only yin
757-
header["x_mesh"] = np.copy(meshes["X"])
758-
header["y_mesh"] = np.copy(meshes["Y"])
759-
header["z_mesh"] = np.copy(meshes["Z"])
760-
header["r_mesh"] = np.sqrt(
761-
header["x_mesh"] ** 2 + header["y_mesh"] ** 2 + header["z_mesh"] ** 2
762-
)
763-
header["t_mesh"] = np.arccos(header["z_mesh"] / header["r_mesh"])
764-
header["p_mesh"] = np.roll(
765-
np.arctan2(header["y_mesh"], -header["x_mesh"]) + np.pi, -1, 1
766-
)
767-
header["e1_coord"] = header["t_mesh"][:, 0, 0]
768-
header["e2_coord"] = header["p_mesh"][0, :, 0]
769-
header["e3_coord"] = header["r_mesh"][0, 0, :]
770-
else:
771-
header["e1_coord"] = meshes["X"][:, 0, 0]
772-
header["e2_coord"] = meshes["Y"][0, :, 0]
773-
header["e3_coord"] = meshes["Z"][0, 0, :]
774-
header["aspect"] = (
775-
header["e1_coord"][-1] - header["e2_coord"][0],
776-
header["e1_coord"][-1] - header["e2_coord"][0],
777-
)
778-
header["rcmb"] = header["e3_coord"][0]
779-
if header["rcmb"] == 0:
780-
header["rcmb"] = -1
781-
else:
782-
header["e3_coord"] = header["e3_coord"] - header["rcmb"]
783-
if twod is None or "X" in twod:
784-
header["e1_coord"] = header["e1_coord"][:-1]
785-
if twod is None or "Y" in twod:
786-
header["e2_coord"] = header["e2_coord"][:-1]
787-
header["e3_coord"] = header["e3_coord"][:-1]
788-
789-
790722
def _try_get(file: Path, elt: Element, key: str) -> str:
791723
"""Try getting an attribute or raise a ParsingError."""
792724
att = elt.get(key)
@@ -959,7 +891,61 @@ def read_geom_h5(xdmf: FieldXmf, snapshot: int) -> dict[str, Any]:
959891
header["mo_lambda"] = entry.mo_lambda
960892
header["mo_thick_sol"] = entry.mo_thick_sol
961893
header["ntb"] = 2 if entry.yin_yang else 1
962-
_read_coord_h5(entry.coord_h5, entry.coord_shape, header, entry.twod)
894+
895+
all_meshes: list[dict[str, NDArray]] = []
896+
for h5file, shape in zip(entry.coord_h5, entry.coord_shape):
897+
all_meshes.append({})
898+
with h5py.File(h5file, "r") as h5f:
899+
for coord, mesh in h5f.items():
900+
# for some reason, the array is transposed!
901+
all_meshes[-1][coord] = mesh[()].reshape(shape).T
902+
all_meshes[-1][coord] = _make_3d(all_meshes[-1][coord], entry.twod)
903+
904+
header["ncs"] = _ncores(all_meshes, entry.twod)
905+
header["nts"] = list(
906+
(all_meshes[0]["X"].shape[i] - 1) * header["ncs"][i] for i in range(3)
907+
)
908+
header["nts"] = np.array([max(1, val) for val in header["nts"]])
909+
meshes = _conglomerate_meshes(all_meshes, header)
910+
if np.any(meshes["Z"][:, :, 0] != 0):
911+
# spherical
912+
if entry.twod is not None: # annulus geometry...
913+
header["x_mesh"] = np.copy(meshes["Y"])
914+
header["y_mesh"] = np.copy(meshes["Z"])
915+
header["z_mesh"] = np.copy(meshes["X"])
916+
else: # YinYang, here only yin
917+
header["x_mesh"] = np.copy(meshes["X"])
918+
header["y_mesh"] = np.copy(meshes["Y"])
919+
header["z_mesh"] = np.copy(meshes["Z"])
920+
header["r_mesh"] = np.sqrt(
921+
header["x_mesh"] ** 2 + header["y_mesh"] ** 2 + header["z_mesh"] ** 2
922+
)
923+
header["t_mesh"] = np.arccos(header["z_mesh"] / header["r_mesh"])
924+
header["p_mesh"] = np.roll(
925+
np.arctan2(header["y_mesh"], -header["x_mesh"]) + np.pi, -1, 1
926+
)
927+
header["e1_coord"] = header["t_mesh"][:, 0, 0]
928+
header["e2_coord"] = header["p_mesh"][0, :, 0]
929+
header["e3_coord"] = header["r_mesh"][0, 0, :]
930+
else:
931+
header["e1_coord"] = meshes["X"][:, 0, 0]
932+
header["e2_coord"] = meshes["Y"][0, :, 0]
933+
header["e3_coord"] = meshes["Z"][0, 0, :]
934+
header["aspect"] = (
935+
header["e1_coord"][-1] - header["e2_coord"][0],
936+
header["e1_coord"][-1] - header["e2_coord"][0],
937+
)
938+
header["rcmb"] = header["e3_coord"][0]
939+
if header["rcmb"] == 0:
940+
header["rcmb"] = -1
941+
else:
942+
header["e3_coord"] = header["e3_coord"] - header["rcmb"]
943+
if entry.twod is None or "X" in entry.twod:
944+
header["e1_coord"] = header["e1_coord"][:-1]
945+
if entry.twod is None or "Y" in entry.twod:
946+
header["e2_coord"] = header["e2_coord"][:-1]
947+
header["e3_coord"] = header["e3_coord"][:-1]
948+
963949
return header
964950

965951

0 commit comments

Comments
 (0)