|
14 | 14 | from typing import ( |
15 | 15 | Any, |
16 | 16 | Optional, |
| 17 | + Union, |
17 | 18 | ) |
18 | 19 |
|
19 | 20 | import numpy as np |
@@ -135,8 +136,7 @@ def __init__( |
135 | 136 | self.shuffle_test = shuffle_test |
136 | 137 | # set modifier |
137 | 138 | self.modifier = modifier |
138 | | - # calculate prefix sum for get_item method |
139 | | - frames_list = [self._get_nframes(item) for item in self.dirs] |
| 139 | + frames_list = [self._get_nframes(set_name) for set_name in self.dirs] |
140 | 140 | self.nframes = np.sum(frames_list) |
141 | 141 | # The prefix sum stores the range of indices contained in each directory, which is needed by get_item method |
142 | 142 | self.prefix_sum = np.cumsum(frames_list).tolist() |
@@ -338,8 +338,10 @@ def get_numb_set(self) -> int: |
338 | 338 |
|
def get_numb_batch(self, batch_size: int, set_idx: int) -> int:
    """Get the number of batches in a set.

    Parameters
    ----------
    batch_size : int
        Number of frames per batch.
    set_idx : int
        Index into ``self.dirs`` selecting the set to inspect.

    Returns
    -------
    int
        ``nframes // batch_size`` for the selected set, but never less
        than 1, so a set smaller than one batch still yields one batch.
    """
    set_name = self.dirs[set_idx]
    # Ask for the frame count directly instead of calling _load_set, so the
    # full dataset does not have to be read just to count batches.
    nframes = self._get_nframes(set_name)
    return max(nframes // batch_size, 1)
@@ -578,18 +580,27 @@ def _shuffle_data(self, data: dict[str, Any]) -> dict[str, Any]: |
578 | 580 | ret[kk] = data[kk] |
579 | 581 | return ret, idx |
580 | 582 |
|
def _get_nframes(self, set_name: Union[DPPath, str]) -> int:
    """Return the number of frames stored in a set without loading its data.

    Only the shape of ``coord.npy`` is inspected: for HDF5-backed sets the
    dataset's shape is read, and for ``.npy`` files on disk only the file
    header is parsed.

    Parameters
    ----------
    set_name : DPPath or str
        The set directory; anything that is not a ``DPPath`` is wrapped
        in one.

    Returns
    -------
    int
        The length of the leading axis of ``coord.npy``. An array with
        fewer than two dimensions is counted as a single frame.

    Raises
    ------
    ValueError
        If the ``.npy`` file declares a major format version other than
        1, 2 or 3.
    """
    if not isinstance(set_name, DPPath):
        set_name = DPPath(set_name)
    path = set_name / "coord.npy"
    if isinstance(set_name, DPH5Path):
        # NOTE(review): relies on DPH5Path internals (`root`, `_name`);
        # assumes `root[...]` yields a dataset exposing `.shape` - confirm.
        shape = path.root[path._name].shape
    else:
        # Read only the header to get the shape.
        with open(str(path), "rb") as f:
            version = np.lib.format.read_magic(f)
            if version[0] == 1:
                shape, _fortran_order, _dtype = (
                    np.lib.format.read_array_header_1_0(f)
                )
            elif version[0] in (2, 3):
                # NumPy exposes no public 3.0 reader; the 2.0 reader is
                # used for both, as only the shape is needed here.
                shape, _fortran_order, _dtype = (
                    np.lib.format.read_array_header_2_0(f)
                )
            else:
                raise ValueError(f"Unsupported .npy file version: {version}")
    # Apply the same rule to both storage backends: a flat (1-D) coordinate
    # array holds exactly one frame.
    return shape[0] if len(shape) > 1 else 1
594 | 605 |
|
595 | 606 | def reformat_data_torch(self, data: dict[str, Any]) -> dict[str, Any]: |
|
0 commit comments