Skip to content

Commit 62264c5

Browse files
authored
🐛fix(storage): materialize hub-loaded constants and keep disk memmaps (#345)
1 parent 26719c2 commit 62264c5

File tree

5 files changed

+141
-5
lines changed

5 files changed

+141
-5
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
2525

2626
### Fixes
2727

28+
- (storage/common/reader) fix constants loading from the Hub: numeric constants are now materialized into in-memory arrays so they remain valid after the temporary download directory is removed; loading from local disk still returns file-backed `np.memmap` arrays for memory efficiency.
2829
- (sample) get_global in the scalar case now returns the scalar with the original type.
2930
- (datasets) fix missing location use in get_field_names, and improve corresponding tests.
3031
- (cgns_helpers) update_features_for_CGNS_compatibility: fix behavior where input variable was modified by the function.

docs/source/tutorials/storage.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@ Notes and requirements:
2323
- Parallel mode demonstrates sharding and num_proc controls; writer_batch_size / num_workers are used when uploading to the hub.
2424
- The tutorial includes utilities to inspect CGNS trees and to save single Plaid samples to disk for visualization (e.g., Paraview).
2525

26+
Constants loading policy (metadata):
27+
- Loading metadata from **local disk** keeps numeric constants as file-backed `np.memmap` for memory efficiency on large datasets.
28+
- Loading metadata from the **Hub** materializes numeric constants into in-memory arrays before returning them, to avoid lifetime issues with temporary download folders.
29+
2630
Use these examples as templates to:
2731
- Adapt generators to your raw data format,
2832
- Choose the backend that fits your workflow (hf_datasets for hub integration, cgns for native CGNS interchange, zarr for efficient chunked numeric storage),

src/plaid/storage/cgns/reader.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,9 @@ def sample_generator(
132132
sample = Sample(
133133
path=Path(temp_folder) / "data" / f"{split}" / f"sample_{idx:09d}"
134134
)
135+
# Sample data are eagerly loaded in memory during initialization;
136+
# clear the transient on-disk path before leaving the temp dir.
137+
sample.path = None
135138
yield sample
136139

137140

src/plaid/storage/common/reader.py

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,22 @@
2525

2626
logger = logging.getLogger(__name__)
2727

28+
29+
def _materialize_memmaps(
    flat_cst: dict[str, dict[str, Any]],
) -> dict[str, dict[str, Any]]:
    """Swap every ``np.memmap`` constant for an independent in-memory copy.

    Intended for constants read from a location that may disappear (for
    example a temporary download directory): once copied, the arrays no
    longer depend on the backing file.

    Args:
        flat_cst: Mapping split -> {constant_name: value}. The per-split
            dictionaries are updated in place.

    Returns:
        The same ``flat_cst`` object, with each ``np.memmap`` value replaced
        by a plain ``np.ndarray`` copy. Values of any other type are left
        untouched.
    """
    for split_constants in flat_cst.values():
        # Build the replacements first, then apply them: only existing keys
        # are rewritten, so the ordering of each split dictionary is kept.
        in_memory = {
            name: np.asarray(constant).copy()
            for name, constant in split_constants.items()
            if isinstance(constant, np.memmap)
        }
        split_constants.update(in_memory)
    return flat_cst
42+
43+
2844
# ------------------------------------------------------
2945
# Load from disk
3046
# ------------------------------------------------------
@@ -107,8 +123,8 @@ def load_constants_from_disk(path):
107123
Returns:
108124
tuple:
109125
flat_cst (dict[str, dict[str, Any]]): Mapping split -> {constant_name: numpy array | None}.
110-
- Numeric constants are returned as numpy arrays with the dtype and shape specified
111-
in the schema.
126+
- Numeric constants are returned as ``np.memmap`` arrays backed by
127+
``data.mmap`` in the dataset directory.
112128
- String constants are returned as 1-element numpy arrays of Python str decoded using ASCII.
113129
- If layout entry for a key is None, the value is returned as None.
114130
constant_schema (dict[str, dict[str, Any]]): Mapping split -> loaded constant schema (from YAML).
@@ -181,7 +197,8 @@ def load_metadata_from_disk(
181197
182198
Returns:
183199
tuple[dict[str, Any], dict[str, Any], dict[str, Any], dict[str, Any]]:
184-
- flat_cst: constant features dictionary
200+
- flat_cst: constant features dictionary (numeric constants kept as
201+
file-backed ``np.memmap``)
185202
- variable_schema: variable schema dictionary
186203
- constant_schema: constant schema dictionary
187204
- cgns_types: CGNS types dictionary
@@ -261,7 +278,8 @@ def load_metadata_from_hub(
261278
262279
Returns:
263280
tuple[dict[str, Any], dict[str, Any], dict[str, Any], dict[str, Any]]:
264-
- flat_cst: constant features dictionary
281+
- flat_cst: constant features dictionary (numeric constants are
282+
materialized to in-memory arrays)
265283
- variable_schema: variable schema dictionary
266284
- constant_schema: constant schema dictionary
267285
- cgns_types: CGNS types dictionary
@@ -275,6 +293,9 @@ def load_metadata_from_hub(
275293
local_dir=temp_folder,
276294
)
277295
flat_cst, constant_schema = load_constants_from_disk(temp_folder)
296+
# Hub metadata is downloaded under a temporary directory: materialize
297+
# memmaps so returned constants remain valid after temp cleanup.
298+
flat_cst = _materialize_memmaps(flat_cst)
278299

279300
# variable_schema
280301
yaml_path = hf_hub_download(

tests/storage/test_storage.py

Lines changed: 108 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,17 @@
77

88
# %% Imports
99

10+
import json
11+
import shutil
1012
from copy import deepcopy
1113
from functools import partial
1214
from pathlib import Path
1315
from typing import Callable
1416

17+
import numpy as np
1518
import pytest
19+
import yaml
1620

17-
# from plaid.bridges import huggingface_bridge
1821
from plaid.containers.dataset import Dataset
1922
from plaid.containers.sample import Sample
2023
from plaid.problem_definition import ProblemDefinition
@@ -28,6 +31,110 @@
2831
)
2932

3033

34+
def test_load_metadata_from_hub_materializes_memmaps(tmp_path, monkeypatch):
    """Check that Hub-loaded numeric constants are plain in-memory arrays.

    A fake repository is laid out under ``tmp_path`` and the reader module's
    ``snapshot_download`` / ``hf_hub_download`` names are patched so that no
    network access happens. The numeric constant returned by
    ``load_metadata_from_hub`` must be an ``np.ndarray`` that is not an
    ``np.memmap`` and must hold the original values.
    """
    from plaid.storage.common import reader as common_reader

    expected = np.arange(6, dtype=np.float32).reshape(2, 3)
    dtype_name = str(expected.dtype)

    # --- fake repository layout -------------------------------------------
    fake_repo = tmp_path / "fake_hub_repo"
    train_dir = fake_repo / "constants" / "train"
    train_dir.mkdir(parents=True)

    # Raw C-ordered bytes, described by the layout entry written below.
    (train_dir / "data.mmap").write_bytes(expected.tobytes(order="C"))

    layout = {
        "Global/cst_numeric": {
            "offset": 0,
            "shape": list(expected.shape),
            "dtype": dtype_name,
        }
    }
    (train_dir / "layout.json").write_text(json.dumps(layout), encoding="utf-8")

    yaml_documents = {
        train_dir / "constant_schema.yaml": {
            "Global/cst_numeric": {"dtype": dtype_name, "ndim": 2}
        },
        fake_repo / "variable_schema.yaml": {
            "Global/var": {"dtype": "float32", "ndim": 1}
        },
        fake_repo / "cgns_types.yaml": {"Global": "DataArray_t"},
    }
    for yaml_path, content in yaml_documents.items():
        yaml_path.write_text(yaml.safe_dump(content), encoding="utf-8")

    # --- stand-ins for the download helpers -------------------------------
    def _copy_constants_snapshot(**kwargs):
        # Mimic a snapshot download by copying the constants tree into the
        # directory requested by the caller.
        target = Path(kwargs["local_dir"])
        shutil.copytree(
            fake_repo / "constants", target / "constants", dirs_exist_ok=True
        )
        return str(target)

    def _resolve_repo_file(**kwargs):
        # Single-file downloads resolve directly inside the fake repository.
        return str(fake_repo / kwargs["filename"])

    monkeypatch.setattr(common_reader, "snapshot_download", _copy_constants_snapshot)
    monkeypatch.setattr(common_reader, "hf_hub_download", _resolve_repo_file)

    metadata = common_reader.load_metadata_from_hub("dummy/repo")
    flat_cst, variable_schema, constant_schema, cgns_types = metadata

    loaded = flat_cst["train"]["Global/cst_numeric"]
    assert isinstance(loaded, np.ndarray)
    assert not isinstance(loaded, np.memmap)
    assert np.array_equal(loaded, expected)
    assert variable_schema["Global/var"]["dtype"] == "float32"
    assert "Global/cst_numeric" in constant_schema["train"]
    assert cgns_types["Global"] == "DataArray_t"
91+
92+
93+
def test_load_metadata_from_disk_keeps_memmaps(tmp_path):
    """Check that disk-loaded numeric constants stay file-backed memmaps.

    A minimal dataset directory is written under ``tmp_path``; the numeric
    constant returned by ``load_metadata_from_disk`` must be an ``np.memmap``
    whose contents match the bytes written to ``data.mmap``.
    """
    from plaid.storage.common import reader as common_reader

    expected = np.arange(6, dtype=np.float32).reshape(2, 3)
    dtype_name = str(expected.dtype)

    # --- dataset layout on disk -------------------------------------------
    dataset_dir = tmp_path / "dataset"
    train_dir = dataset_dir / "constants" / "train"
    train_dir.mkdir(parents=True)

    # Raw C-ordered bytes, described by the layout entry written below.
    (train_dir / "data.mmap").write_bytes(expected.tobytes(order="C"))

    layout = {
        "Global/cst_numeric": {
            "offset": 0,
            "shape": list(expected.shape),
            "dtype": dtype_name,
        }
    }
    (train_dir / "layout.json").write_text(json.dumps(layout), encoding="utf-8")

    yaml_documents = {
        train_dir / "constant_schema.yaml": {
            "Global/cst_numeric": {"dtype": dtype_name, "ndim": 2}
        },
        dataset_dir / "variable_schema.yaml": {
            "Global/var": {"dtype": "float32", "ndim": 1}
        },
        dataset_dir / "cgns_types.yaml": {"Global": "DataArray_t"},
    }
    for yaml_path, content in yaml_documents.items():
        yaml_path.write_text(yaml.safe_dump(content), encoding="utf-8")

    metadata = common_reader.load_metadata_from_disk(dataset_dir)
    flat_cst, variable_schema, constant_schema, cgns_types = metadata

    loaded = flat_cst["train"]["Global/cst_numeric"]
    assert isinstance(loaded, np.memmap)
    assert np.array_equal(np.asarray(loaded), expected)
    assert variable_schema["Global/var"]["dtype"] == "float32"
    assert "Global/cst_numeric" in constant_schema["train"]
    assert cgns_types["Global"] == "DataArray_t"
136+
137+
31138
@pytest.fixture()
def current_directory():
    """Provide the absolute path of the directory containing this test file."""
    this_file = Path(__file__).absolute()
    return this_file.parent

0 commit comments

Comments
 (0)