Skip to content

Commit a28929d

Browse files
committed
use new _Cursor to read legacy tracers file
1 parent 0048c9c commit a28929d

File tree

2 files changed

+52
-21
lines changed

2 files changed

+52
-21
lines changed

src/stagpy/stagyyparsers.py

Lines changed: 49 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,36 @@ def refstate(
339339
return syst, adia
340340

341341

342+
@dataclass(frozen=True)
343+
class _Cursor:
344+
fid: BinaryIO
345+
int_type: type[np.integer]
346+
float_type: type[np.floating]
347+
348+
def reset_with_64_bits(self) -> _Cursor:
349+
self.fid.seek(0)
350+
return _Cursor(
351+
fid=self.fid,
352+
int_type=np.int64,
353+
float_type=np.float64,
354+
)
355+
356+
def string(self, nbytes: int) -> str:
357+
return b"".join(np.fromfile(self.fid, "b", nbytes)).strip().decode()
358+
359+
def single_int(self) -> np.integer:
360+
return np.fromfile(self.fid, self.int_type, 1)[0]
361+
362+
def single_float(self) -> np.floating:
363+
return np.fromfile(self.fid, self.float_type, 1)[0]
364+
365+
def ints(self, count: int | np.integer) -> NDArray[np.integer]:
366+
return np.fromfile(self.fid, self.int_type, count)
367+
368+
def floats(self, count: int | np.integer) -> NDArray[np.floating]:
369+
return np.fromfile(self.fid, self.float_type, count)
370+
371+
342372
def _readbin(
343373
fid: BinaryIO,
344374
fmt: str = "i",
@@ -563,7 +593,7 @@ def fields(fieldfile: Path) -> tuple[dict[str, Any], NDArray[np.float64]] | None
563593
return header, flds
564594

565595

566-
def tracers(tracersfile: Path) -> dict[str, list[NDArray[np.float64]]] | None:
596+
def tracers(tracersfile: Path) -> dict[str, list[NDArray[np.floating]]] | None:
567597
"""Extract tracers data.
568598
569599
Args:
@@ -574,38 +604,39 @@ def tracers(tracersfile: Path) -> dict[str, list[NDArray[np.float64]]] | None:
574604
"""
575605
if not tracersfile.is_file():
576606
return None
577-
tra: dict[str, list[NDArray[np.float64]]] = {}
607+
tra: dict[str, list[NDArray[np.floating]]] = {}
578608
with tracersfile.open("rb") as fid:
579-
readbin = partial(_readbin, fid)
580-
magic = readbin()
609+
cursor = _Cursor(fid=fid, int_type=np.int32, float_type=np.float32)
610+
magic = cursor.single_int().item()
581611
if magic > 8000: # 64 bits
612+
cursor = cursor.reset_with_64_bits()
613+
if magic != cursor.single_int():
614+
raise ParsingError(tracersfile, "inconsistent magic number in 64 bits")
582615
magic -= 8000
583-
readbin()
584-
readbin = partial(readbin, file64=True)
585616
if magic < 100:
586617
raise ParsingError(
587618
tracersfile, "magic > 100 expected to get tracervar info"
588619
)
589620
nblk = magic % 100
590-
readbin("f", 2) # aspect ratio
591-
readbin() # istep
592-
readbin("f") # time
593-
ninfo = readbin()
594-
ntra = readbin(nwords=nblk, unpack=False)
595-
readbin("f") # tracer ideal mass
596-
curv = readbin()
621+
cursor.floats(2) # aspect ratio
622+
cursor.single_int() # istep
623+
cursor.single_float() # time
624+
ninfo = cursor.single_int()
625+
ntra = cursor.ints(nblk)
626+
cursor.single_float() # tracer ideal mass
627+
curv = cursor.single_int()
597628
if curv:
598-
readbin("f") # r_cmb
629+
cursor.single_float() # r_cmb
599630
infos = [] # list of info names
600631
for _ in range(ninfo):
601-
infos.append(b"".join(readbin("b", 16)).strip().decode())
632+
infos.append(cursor.string(16))
602633
tra[infos[-1]] = []
603634
if magic > 200:
604-
ntrace_elt = readbin()
635+
ntrace_elt = cursor.single_int()
605636
if ntrace_elt > 0:
606-
readbin("f", ntrace_elt) # outgassed
637+
cursor.floats(ntrace_elt) # outgassed
607638
for ntrab in ntra: # blocks
608-
data = readbin("f", ntrab * ninfo)
639+
data = cursor.floats(ntrab * ninfo)
609640
for idx, info in enumerate(infos):
610641
tra[info].append(data[idx::ninfo])
611642
return tra

src/stagpy/step.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -394,10 +394,10 @@ class Tracers:
394394
step: Step
395395

396396
@cached_property
397-
def _data(self) -> dict[str, list[NDArray[np.float64]] | None]:
397+
def _data(self) -> dict[str, list[NDArray[np.floating]] | None]:
398398
return {}
399399

400-
def __getitem__(self, name: str) -> list[NDArray[np.float64]] | None:
400+
def __getitem__(self, name: str) -> list[NDArray[np.floating]] | None:
401401
if name in self._data:
402402
return self._data[name]
403403
if self.step.isnap is None:
@@ -406,7 +406,7 @@ def __getitem__(self, name: str) -> list[NDArray[np.float64]] | None:
406406
self.step.sdat.filename("tra", timestep=self.step.isnap, force_legacy=True)
407407
)
408408
if data is None and self.step.sdat.hdf5:
409-
self._data[name] = stagyyparsers.read_tracers_h5(
409+
self._data[name] = stagyyparsers.read_tracers_h5( # type: ignore
410410
self.step.sdat._traxmf,
411411
name,
412412
self.step.isnap,

0 commit comments

Comments
 (0)