Skip to content

Commit 66a2835

Browse files
committed
address revisions
1 parent f7203d1 commit 66a2835

File tree

2 files changed

+136
-67
lines changed

2 files changed

+136
-67
lines changed

tests/test_data_loading.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import pandas as pd
1414
import zipfile
1515

16-
from wsimod.orchestration.model import Model, PANDAS_AVAILABLE
16+
from wsimod.orchestration.model import Model, PARQUET_AVAILABLE
1717

1818

1919
def _unzip_model_data(temp_dir: str):
@@ -54,6 +54,48 @@ def test_load_save(self):
5454
df2.time = df2.time.astype(str)
5555
pd.testing.assert_frame_equal(df, df2)
5656

57+
def test_misc_load_save(self):
58+
"""Test miscellaneous load and save functionality."""
59+
60+
from wsimod.orchestration.model import Model
61+
62+
model = Model()
63+
model.add_nodes(
64+
[
65+
{
66+
"name": "my-land2",
67+
"type_": "Land",
68+
"surfaces": [
69+
{
70+
"surface": "my_surface",
71+
"area": 1,
72+
"type_": "GrowingSurface",
73+
"initial_storage": {
74+
"phosphate": 1,
75+
"nitrate": 2,
76+
"nitrite": 3,
77+
"ammonia": 4,
78+
"org-nitrogen": 5,
79+
"org-phosphorus": 6,
80+
},
81+
},
82+
],
83+
}
84+
]
85+
)
86+
with tempfile.TemporaryDirectory() as temp_dir:
87+
model.save(temp_dir)
88+
model = Model()
89+
model.load(temp_dir)
90+
assert model.nodes["my-land2"].surfaces[0].initial_storage == {
91+
"phosphate": 1,
92+
"nitrate": 2,
93+
"nitrite": 3,
94+
"ammonia": 4,
95+
"org-nitrogen": 5,
96+
"org-phosphorus": 6,
97+
}
98+
5799
def test_performance_comparison(self):
58100
"""Compare performance of original vs unified data save/load/run."""
59101
with tempfile.TemporaryDirectory() as temp_dir:

wsimod/orchestration/model.py

Lines changed: 93 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,9 @@
2424
pyarrow_available = importlib.util.find_spec("pyarrow") is not None
2525
fastparquet_available = importlib.util.find_spec("fastparquet") is not None
2626

27-
PANDAS_AVAILABLE = pyarrow_available or fastparquet_available
27+
PARQUET_AVAILABLE = pyarrow_available or fastparquet_available
2828
except ImportError:
29-
PANDAS_AVAILABLE = False
29+
PARQUET_AVAILABLE = False
3030

3131
from wsimod.arcs import arcs as arcs_mod
3232
from wsimod.core import constants
@@ -259,66 +259,10 @@ def load(self, address, config_name="config.yml", overrides={}):
259259

260260
# Check if using unified data file
261261
unified_data_file = data.get("unified_data_file")
262-
if unified_data_file and PANDAS_AVAILABLE:
263-
264-
# Load unified data
265-
unified_data_path = os.path.join(address, unified_data_file)
266-
self.unified_data = pd.read_parquet(unified_data_path)
267-
268-
# Create a single comprehensive data dictionary for all nodes
269-
# This is much faster than individual lookups
270-
surface_data = self.unified_data.dropna(subset=["surface"]).copy()
271-
surface_data["time"] = pd.to_datetime(surface_data["time"]).dt.to_period(
272-
"M"
273-
)
274-
surface_dict = (
275-
surface_data.groupby(["node", "surface"])
276-
.apply(lambda x: x.set_index(["variable", "time"]).value.to_dict())
277-
.to_dict()
278-
)
279-
280-
node_data = self.unified_data.loc[self.unified_data.surface.isna()].copy()
281-
node_data["time"] = pd.to_datetime(node_data["time"])
282-
node_dict = (
283-
node_data.groupby("node")
284-
.apply(lambda x: x.set_index(["variable", "time"]).value.to_dict())
285-
.to_dict()
286-
)
287-
288-
# Assign the same comprehensive data dict to all nodes that need it
289-
for name, node in nodes.items():
290-
if "data_input_dict" in node.keys() and node["data_input_dict"]:
291-
node["data_input_dict"] = node_dict[name]
292-
if "surfaces" in node.keys():
293-
for key, surface in node["surfaces"].items():
294-
if (
295-
"data_input_dict" in surface.keys()
296-
and surface["data_input_dict"]
297-
):
298-
node["surfaces"][key]["data_input_dict"] = surface_dict[
299-
(name, key)
300-
]
301-
node["surfaces"] = list(node["surfaces"].values())
302-
if "dates" in data.keys():
303-
self.dates = pd.to_datetime(data["dates"])
262+
if unified_data_file and PARQUET_AVAILABLE:
263+
self._load_unified_data(address, unified_data_file, nodes, data)
304264
else:
305-
# Use individual files (original behavior)
306-
for name, node in nodes.items():
307-
if "filename" in node.keys():
308-
node["data_input_dict"] = read_csv(
309-
os.path.join(address, node["filename"])
310-
)
311-
del node["filename"]
312-
if "surfaces" in node.keys():
313-
for key, surface in node["surfaces"].items():
314-
if "filename" in surface.keys():
315-
node["surfaces"][key]["data_input_dict"] = read_csv(
316-
os.path.join(address, surface["filename"])
317-
)
318-
del surface["filename"]
319-
node["surfaces"] = list(node["surfaces"].values())
320-
if "dates" in data.keys():
321-
self.dates = [to_datetime(x) for x in data["dates"]]
265+
self._load_individual_files(address, nodes, data)
322266
arcs = data.get("arcs", {})
323267
self.add_nodes(list(nodes.values()))
324268
self.add_arcs(list(arcs.values()))
@@ -327,6 +271,80 @@ def load(self, address, config_name="config.yml", overrides={}):
327271

328272
apply_patches(self)
329273

274+
def _load_unified_data(self, address, unified_data_file, nodes, data):
275+
"""Load model data from a unified parquet file.
276+
277+
Args:
278+
address (str): Path to directory containing the unified data file
279+
unified_data_file (str): Name of the unified parquet file
280+
nodes (dict): Dictionary of node configurations
281+
data (dict): Full configuration dictionary
282+
"""
283+
# Load unified data
284+
unified_data_path = os.path.join(address, unified_data_file)
285+
self.unified_data = pd.read_parquet(unified_data_path)
286+
287+
# Create a single comprehensive data dictionary for all nodes
288+
# This is much faster than individual lookups
289+
surface_data = self.unified_data.dropna(subset=["surface"]).copy()
290+
surface_data["time"] = pd.to_datetime(surface_data["time"]).dt.to_period("M")
291+
surface_dict = (
292+
surface_data.groupby(["node", "surface"])
293+
.apply(lambda x: x.set_index(["variable", "time"]).value.to_dict())
294+
.to_dict()
295+
)
296+
297+
node_data = self.unified_data.loc[self.unified_data.surface.isna()].copy()
298+
node_data["time"] = pd.to_datetime(node_data["time"])
299+
node_dict = (
300+
node_data.groupby("node")
301+
.apply(lambda x: x.set_index(["variable", "time"]).value.to_dict())
302+
.to_dict()
303+
)
304+
305+
# Assign the same comprehensive data dict to all nodes that need it
306+
for name, node in nodes.items():
307+
if "data_input_dict" in node.keys() and node["data_input_dict"]:
308+
node["data_input_dict"] = node_dict[name]
309+
if "surfaces" in node.keys():
310+
for key, surface in node["surfaces"].items():
311+
if (
312+
"data_input_dict" in surface.keys()
313+
and surface["data_input_dict"]
314+
):
315+
node["surfaces"][key]["data_input_dict"] = surface_dict[
316+
(name, key)
317+
]
318+
node["surfaces"] = list(node["surfaces"].values())
319+
if "dates" in data.keys():
320+
self.dates = pd.to_datetime(data["dates"])
321+
322+
def _load_individual_files(self, address, nodes, data):
323+
"""Load model data from individual CSV files (original behavior).
324+
325+
Args:
326+
address (str): Path to directory containing the data files
327+
nodes (dict): Dictionary of node configurations
328+
data (dict): Full configuration dictionary
329+
"""
330+
# Use individual files (original behavior)
331+
for name, node in nodes.items():
332+
if "filename" in node.keys():
333+
node["data_input_dict"] = read_csv(
334+
os.path.join(address, node["filename"])
335+
)
336+
del node["filename"]
337+
if "surfaces" in node.keys():
338+
for key, surface in node["surfaces"].items():
339+
if "filename" in surface.keys():
340+
node["surfaces"][key]["data_input_dict"] = read_csv(
341+
os.path.join(address, surface["filename"])
342+
)
343+
del surface["filename"]
344+
node["surfaces"] = list(node["surfaces"].values())
345+
if "dates" in data.keys():
346+
self.dates = [to_datetime(x) for x in data["dates"]]
347+
330348
def save(self, address, config_name="config.yml", compress=False):
331349
"""Save the model object to a yaml file and input data to csv.gz format in the
332350
directory specified.
@@ -373,7 +391,9 @@ def _save_model_config(
373391
special_args = set(["surfaces", "parent", "data_input_dict"])
374392

375393
node_props = {
376-
x: getattr(node, x) for x in set(init_args).difference(special_args)
394+
x: getattr(node, x)
395+
for x in set(init_args).difference(special_args)
396+
if hasattr(node, x)
377397
}
378398
node_props["type_"] = node.__class__.__name__
379399
node_props["node_type_override"] = (
@@ -387,6 +407,7 @@ def _save_model_config(
387407
surface_props = {
388408
x: getattr(surface, x)
389409
for x in set(surface_args).difference(special_args)
410+
if hasattr(surface, x)
390411
}
391412
surface_props["type_"] = surface.__class__.__name__
392413

@@ -542,8 +563,11 @@ def save_unified_data(
542563
config_name (str): Name of the config file
543564
compress (bool): Whether to compress (not used for parquet)
544565
"""
545-
if not PANDAS_AVAILABLE:
546-
raise ImportError("pandas is required for unified data saving")
566+
if not PARQUET_AVAILABLE:
567+
raise ImportError(
568+
"parquet support (pyarrow or fastparquet) is required "
569+
"for unified data saving"
570+
)
547571

548572
if not os.path.exists(address):
549573
os.mkdir(address)
@@ -1278,8 +1302,11 @@ def create_unified_dataframe(nodes_data, surfaces_data=None):
12781302
pd.DataFrame: Unified DataFrame with columns: node, surface, variable,
12791303
time, value
12801304
"""
1281-
if not PANDAS_AVAILABLE:
1282-
raise ImportError("pandas is required for unified DataFrame support")
1305+
if not PARQUET_AVAILABLE:
1306+
raise ImportError(
1307+
"parquet support (pyarrow or fastparquet) is required "
1308+
"for unified DataFrame support"
1309+
)
12831310

12841311
rows = []
12851312

0 commit comments

Comments
 (0)