@@ -477,7 +477,7 @@ def field_header(fieldfile: Path) -> dict[str, Any] | None:
477477 return hdr .header
478478
479479
480- def fields (fieldfile : Path ) -> tuple [dict [str , Any ], NDArray ] | None :
480+ def fields (fieldfile : Path ) -> tuple [dict [str , Any ], NDArray [ np . float64 ] ] | None :
481481 """Extract fields data.
482482
483483 Args:
@@ -555,7 +555,7 @@ def fields(fieldfile: Path) -> tuple[dict[str, Any], NDArray] | None:
555555 return header , flds
556556
557557
558- def tracers (tracersfile : Path ) -> dict [str , list [NDArray ]] | None :
558+ def tracers (tracersfile : Path ) -> dict [str , list [NDArray [ np . float64 ] ]] | None :
559559 """Extract tracers data.
560560
561561 Args:
@@ -566,7 +566,7 @@ def tracers(tracersfile: Path) -> dict[str, list[NDArray]] | None:
566566 """
567567 if not tracersfile .is_file ():
568568 return None
569- tra : dict [str , list [NDArray ]] = {}
569+ tra : dict [str , list [NDArray [ np . float64 ] ]] = {}
570570 with tracersfile .open ("rb" ) as fid :
571571 readbin = partial (_readbin , fid )
572572 magic = readbin ()
@@ -603,7 +603,7 @@ def tracers(tracersfile: Path) -> dict[str, list[NDArray]] | None:
603603 return tra
604604
605605
606- def _read_group_h5 (filename : Path , groupname : str ) -> NDArray :
606+ def _read_group_h5 (filename : Path , groupname : str ) -> NDArray [ np . float64 ] :
607607 """Return group content.
608608
609609 Args:
@@ -623,7 +623,7 @@ def _read_group_h5(filename: Path, groupname: str) -> NDArray:
623623 return data # need to be reshaped
624624
625625
626- def _make_3d (field : NDArray , twod : str | None ) -> NDArray :
626+ def _make_3d (field : NDArray [ np . float64 ] , twod : str | None ) -> NDArray [ np . float64 ] :
627627 """Add a dimension to field if necessary.
628628
629629 Args:
@@ -641,7 +641,9 @@ def _make_3d(field: NDArray, twod: str | None) -> NDArray:
641641 return field .reshape (shp )
642642
643643
644- def _ncores (meshes : list [dict [str , NDArray ]], twod : str | None ) -> NDArray :
644+ def _ncores (
645+ meshes : list [dict [str , NDArray [np .float64 ]]], twod : str | None
646+ ) -> NDArray [np .float64 ]:
645647 """Compute number of nodes in each direction."""
646648 nnpb = len (meshes ) # number of nodes per block
647649 nns = [1 , 1 , 1 ] # number of nodes in x, y, z directions
@@ -672,8 +674,8 @@ def _ncores(meshes: list[dict[str, NDArray]], twod: str | None) -> NDArray:
672674
673675
674676def _conglomerate_meshes (
675- meshin : list [dict [str , NDArray ]], header : dict [str , Any ]
676- ) -> dict [str , NDArray ]:
677+ meshin : list [dict [str , NDArray [ np . float64 ] ]], header : dict [str , Any ]
678+ ) -> dict [str , NDArray [ np . float64 ] ]:
677679 """Conglomerate meshes from several cores into one."""
678680 meshout = {}
679681 npc = header ["nts" ] // header ["ncs" ]
@@ -870,7 +872,7 @@ def read_geom_h5(xdmf: FieldXmf, snapshot: int) -> dict[str, Any]:
870872 header ["mo_thick_sol" ] = entry .mo_thick_sol
871873 header ["ntb" ] = 2 if entry .yin_yang else 1
872874
873- all_meshes : list [dict [str , NDArray ]] = []
875+ all_meshes : list [dict [str , NDArray [ np . float64 ] ]] = []
874876 for h5file in entry .coord_files_yin (xdmf .path .parent ):
875877 all_meshes .append ({})
876878 with h5py .File (h5file , "r" ) as h5f :
@@ -927,7 +929,9 @@ def read_geom_h5(xdmf: FieldXmf, snapshot: int) -> dict[str, Any]:
927929 return header
928930
929931
930- def _to_spherical (flds : NDArray , header : dict [str , Any ]) -> NDArray :
932+ def _to_spherical (
933+ flds : NDArray [np .float64 ], header : dict [str , Any ]
934+ ) -> NDArray [np .float64 ]:
931935 """Convert vector field to spherical."""
932936 cth = np .cos (header ["t_mesh" ][:, :, :- 1 ])
933937 sth = np .sin (header ["t_mesh" ][:, :, :- 1 ])
@@ -962,7 +966,9 @@ def _flds_shape(fieldname: str, header: dict[str, Any]) -> list[int]:
962966 return shp
963967
964968
965- def _post_read_flds (flds : NDArray , header : dict [str , Any ]) -> NDArray :
969+ def _post_read_flds (
970+ flds : NDArray [np .float64 ], header : dict [str , Any ]
971+ ) -> NDArray [np .float64 ]:
966972 """Process flds to handle sphericity."""
967973 if flds .shape [0 ] >= 3 and header ["rcmb" ] > 0 :
968974 # spherical vector
@@ -982,7 +988,7 @@ def read_field_h5(
982988 fieldname : str ,
983989 snapshot : int ,
984990 header : dict [str , Any ] | None = None ,
985- ) -> tuple [dict [str , Any ], NDArray ] | None :
991+ ) -> tuple [dict [str , Any ], NDArray [ np . float64 ] ] | None :
986992 """Extract field data from hdf5 files.
987993
988994 Args:
@@ -1146,7 +1152,9 @@ def __getitem__(self, isnap: int) -> XmfTracersEntry:
11461152 raise ParsingError (self .path , f"no data for snapshot { isnap } " )
11471153
11481154
1149- def read_tracers_h5 (xdmf : TracersXmf , infoname : str , snapshot : int ) -> list [NDArray ]:
1155+ def read_tracers_h5 (
1156+ xdmf : TracersXmf , infoname : str , snapshot : int
1157+ ) -> list [NDArray [np .float64 ]]:
11501158 """Extract tracers data from hdf5 files.
11511159
11521160 Args:
@@ -1157,11 +1165,11 @@ def read_tracers_h5(xdmf: TracersXmf, infoname: str, snapshot: int) -> list[NDAr
11571165 Returns:
11581166 Tracers data organized by attribute and block.
11591167 """
1160- tra : list [list [NDArray ]] = [[], []] # [block][core]
1168+ tra : list [list [NDArray [ np . float64 ] ]] = [[], []] # [block][core]
11611169 for tsub in xdmf [snapshot ].tra_subdomains (xdmf .path .parent , infoname ):
11621170 tra [tsub .iblock ].append (_read_group_h5 (tsub .file , tsub .dataset ))
11631171
1164- tra_concat : list [NDArray ] = []
1172+ tra_concat : list [NDArray [ np . float64 ] ] = []
11651173 for trab in tra :
11661174 if trab :
11671175 tra_concat .append (np .concatenate (trab ))
0 commit comments