Skip to content

Commit 882789b

Browse files
committed
use _Cursor to parse legacy fields
1 parent a28929d commit 882789b

File tree

1 file changed

+31
-53
lines changed

1 file changed

+31
-53
lines changed

src/stagpy/stagyyparsers.py

Lines changed: 31 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import re
1212
import typing
1313
from dataclasses import dataclass
14-
from functools import cached_property, partial
14+
from functools import cached_property
1515
from itertools import product
1616
from operator import itemgetter
1717

@@ -26,7 +26,7 @@
2626
if typing.TYPE_CHECKING:
2727
from collections.abc import Iterator, Mapping
2828
from pathlib import Path
29-
from typing import Any, BinaryIO, Callable
29+
from typing import Any, BinaryIO
3030
from xml.etree.ElementTree import Element
3131

3232
from numpy.typing import NDArray
@@ -369,51 +369,28 @@ def floats(self, count: int | np.integer) -> NDArray[np.floating]:
369369
return np.fromfile(self.fid, self.float_type, count)
370370

371371

372-
def _readbin(
373-
fid: BinaryIO,
374-
fmt: str = "i",
375-
nwords: int = 1,
376-
file64: bool = False,
377-
unpack: bool = True,
378-
) -> Any:
379-
"""Read n words of 4 or 8 bytes with fmt format.
380-
381-
fmt: 'i' or 'f' or 'b' (integer or float or bytes)
382-
4 or 8 bytes: depends on header
383-
384-
Return an array of elements if more than one element.
385-
386-
Default: read 1 word formatted as an integer.
387-
"""
388-
if fmt in "if":
389-
fmt += "8" if file64 else "4"
390-
elts = np.fromfile(fid, fmt, nwords)
391-
if unpack and len(elts) == 1:
392-
elts = elts[0]
393-
return elts
394-
395-
396372
@dataclass(frozen=True)
397373
class _HeaderInfo:
398374
"""Header information."""
399375

400376
magic: int
401377
nval: int
402378
sfield: bool
403-
readbin: Callable
379+
cursor: _Cursor
404380
header: dict[str, Any]
405381

406382

407383
def _legacy_header(
408384
filepath: Path, fid: BinaryIO, stop_at_istep: bool = False
409385
) -> _HeaderInfo:
410386
"""Read the header of a legacy binary file."""
411-
readbin = partial(_readbin, fid)
412-
magic = readbin()
387+
cursor = _Cursor(fid=fid, int_type=np.int32, float_type=np.float32)
388+
magic = cursor.single_int().item()
413389
if magic > 8000: # 64 bits
390+
cursor = cursor.reset_with_64_bits()
391+
if magic != cursor.single_int():
392+
raise ParsingError(filepath, "inconsistent magic number in 64 bits")
414393
magic -= 8000
415-
readbin() # need to read 4 more bytes
416-
readbin = partial(readbin, file64=True)
417394

418395
# check nb components
419396
nval = 1
@@ -429,56 +406,56 @@ def _legacy_header(
429406
if magic < 9 or magic > 12:
430407
raise ParsingError(filepath, f"{magic=:d} not supported")
431408

432-
header_info = _HeaderInfo(magic, nval, sfield, readbin, {})
409+
header_info = _HeaderInfo(magic, nval, sfield, cursor, {})
433410
header = header_info.header
434411
# extra ghost point in horizontal direction
435412
header["xyp"] = int(nval == 4) # magic >= 9
436413

437414
# total number of values in relevant space basis
438415
# (e1, e2, e3) = (theta, phi, radius) in spherical geometry
439416
# = (x, y, z) in cartesian geometry
440-
header["nts"] = readbin(nwords=3)
417+
header["nts"] = cursor.ints(3)
441418

442419
# number of blocks, 2 for yinyang or cubed sphere
443-
header["ntb"] = readbin() # magic >= 7
420+
header["ntb"] = cursor.single_int() # magic >= 7
444421

445422
# aspect ratio
446-
header["aspect"] = readbin("f", 2)
423+
header["aspect"] = cursor.floats(2)
447424

448425
# number of parallel subdomains
449-
header["ncs"] = readbin(nwords=3) # (e1, e2, e3) space
450-
header["ncb"] = readbin() # magic >= 8, blocks
426+
header["ncs"] = cursor.ints(3) # (e1, e2, e3) space
427+
header["ncb"] = cursor.single_int() # magic >= 8, blocks
451428

452429
# r - coordinates
453430
# rgeom[0:self.nrtot+1, 0] are edge radial position
454431
# rgeom[0:self.nrtot, 1] are cell-center radial position
455-
header["rgeom"] = readbin("f", header["nts"][2] * 2 + 1) # magic >= 2
432+
header["rgeom"] = cursor.floats(header["nts"][2] * 2 + 1) # magic >= 2
456433
header["rgeom"] = np.resize(header["rgeom"], (header["nts"][2] + 1, 2))
457434

458-
header["rcmb"] = readbin("f") # magic >= 7
435+
header["rcmb"] = cursor.single_float() # magic >= 7
459436

460-
header["ti_step"] = readbin() # magic >= 3
437+
header["ti_step"] = cursor.single_int() # magic >= 3
461438
if stop_at_istep:
462439
return header_info
463440

464-
header["ti_ad"] = readbin("f") # magic >= 3
465-
header["erupta_total"] = readbin("f") # magic >= 5
441+
header["ti_ad"] = cursor.single_float() # magic >= 3
442+
header["erupta_total"] = cursor.single_float() # magic >= 5
466443
if magic >= 12:
467-
header["erupta_ttg"] = readbin("f")
468-
header["intruda"] = readbin("f", 2)
469-
header["ttg_mass"] = readbin("f", 3)
444+
header["erupta_ttg"] = cursor.single_float()
445+
header["intruda"] = cursor.floats(2)
446+
header["ttg_mass"] = cursor.floats(3)
470447
else:
471448
header["erupta_ttg"] = 0.0
472449
header["intruda"] = np.zeros(2)
473450
header["ttg_mass"] = np.zeros(3)
474-
header["bot_temp"] = readbin("f") # magic >= 6
475-
header["core_temp"] = readbin("f") if magic >= 10 else 1
476-
header["ocean_mass"] = readbin("f") if magic >= 11 else 0.0
451+
header["bot_temp"] = cursor.single_float() # magic >= 6
452+
header["core_temp"] = cursor.single_float() if magic >= 10 else 1.0
453+
header["ocean_mass"] = cursor.single_float() if magic >= 11 else 0.0
477454

478455
# magic >= 4
479-
header["e1_coord"] = readbin("f", header["nts"][0])
480-
header["e2_coord"] = readbin("f", header["nts"][1])
481-
header["e3_coord"] = readbin("f", header["nts"][2])
456+
header["e1_coord"] = cursor.floats(header["nts"][0])
457+
header["e2_coord"] = cursor.floats(header["nts"][1])
458+
header["e3_coord"] = cursor.floats(header["nts"][2])
482459

483460
return header_info
484461

@@ -530,6 +507,7 @@ def fields(fieldfile: Path) -> tuple[dict[str, Any], NDArray[np.float64]] | None
530507
with fieldfile.open("rb") as fid:
531508
hdr = _legacy_header(fieldfile, fid)
532509
header = hdr.header
510+
cursor = hdr.cursor
533511

534512
# READ FIELDS
535513
# number of points in (e1, e2, e3) directions PER CPU
@@ -545,7 +523,7 @@ def fields(fieldfile: Path) -> tuple[dict[str, Any], NDArray[np.float64]] | None
545523
* hdr.nval
546524
)
547525

548-
header["scalefac"] = hdr.readbin("f") if hdr.nval > 1 else 1
526+
header["scalefac"] = cursor.single_float() if hdr.nval > 1 else 1.0
549527

550528
flds = np.zeros(
551529
(
@@ -565,7 +543,7 @@ def fields(fieldfile: Path) -> tuple[dict[str, Any], NDArray[np.float64]] | None
565543
range(header["ncs"][0]),
566544
):
567545
# read the data for one CPU
568-
data_cpu = hdr.readbin("f", npi) * header["scalefac"]
546+
data_cpu = cursor.floats(npi) * header["scalefac"]
569547

570548
# icpu is (icpu block, icpu z, icpu y, icpu x)
571549
# data from file is transposed to obtained a field

0 commit comments

Comments
 (0)