2424 pyarrow_available = importlib .util .find_spec ("pyarrow" ) is not None
2525 fastparquet_available = importlib .util .find_spec ("fastparquet" ) is not None
2626
27- PANDAS_AVAILABLE = pyarrow_available or fastparquet_available
27+ PARQUET_AVAILABLE = pyarrow_available or fastparquet_available
2828except ImportError :
29- PANDAS_AVAILABLE = False
29+ PARQUET_AVAILABLE = False
3030
3131from wsimod .arcs import arcs as arcs_mod
3232from wsimod .core import constants
@@ -259,66 +259,10 @@ def load(self, address, config_name="config.yml", overrides={}):
259259
260260 # Check if using unified data file
261261 unified_data_file = data .get ("unified_data_file" )
262- if unified_data_file and PANDAS_AVAILABLE :
263-
264- # Load unified data
265- unified_data_path = os .path .join (address , unified_data_file )
266- self .unified_data = pd .read_parquet (unified_data_path )
267-
268- # Create a single comprehensive data dictionary for all nodes
269- # This is much faster than individual lookups
270- surface_data = self .unified_data .dropna (subset = ["surface" ]).copy ()
271- surface_data ["time" ] = pd .to_datetime (surface_data ["time" ]).dt .to_period (
272- "M"
273- )
274- surface_dict = (
275- surface_data .groupby (["node" , "surface" ])
276- .apply (lambda x : x .set_index (["variable" , "time" ]).value .to_dict ())
277- .to_dict ()
278- )
279-
280- node_data = self .unified_data .loc [self .unified_data .surface .isna ()].copy ()
281- node_data ["time" ] = pd .to_datetime (node_data ["time" ])
282- node_dict = (
283- node_data .groupby ("node" )
284- .apply (lambda x : x .set_index (["variable" , "time" ]).value .to_dict ())
285- .to_dict ()
286- )
287-
288- # Assign the same comprehensive data dict to all nodes that need it
289- for name , node in nodes .items ():
290- if "data_input_dict" in node .keys () and node ["data_input_dict" ]:
291- node ["data_input_dict" ] = node_dict [name ]
292- if "surfaces" in node .keys ():
293- for key , surface in node ["surfaces" ].items ():
294- if (
295- "data_input_dict" in surface .keys ()
296- and surface ["data_input_dict" ]
297- ):
298- node ["surfaces" ][key ]["data_input_dict" ] = surface_dict [
299- (name , key )
300- ]
301- node ["surfaces" ] = list (node ["surfaces" ].values ())
302- if "dates" in data .keys ():
303- self .dates = pd .to_datetime (data ["dates" ])
262+ if unified_data_file and PARQUET_AVAILABLE :
263+ self ._load_unified_data (address , unified_data_file , nodes , data )
304264 else :
305- # Use individual files (original behavior)
306- for name , node in nodes .items ():
307- if "filename" in node .keys ():
308- node ["data_input_dict" ] = read_csv (
309- os .path .join (address , node ["filename" ])
310- )
311- del node ["filename" ]
312- if "surfaces" in node .keys ():
313- for key , surface in node ["surfaces" ].items ():
314- if "filename" in surface .keys ():
315- node ["surfaces" ][key ]["data_input_dict" ] = read_csv (
316- os .path .join (address , surface ["filename" ])
317- )
318- del surface ["filename" ]
319- node ["surfaces" ] = list (node ["surfaces" ].values ())
320- if "dates" in data .keys ():
321- self .dates = [to_datetime (x ) for x in data ["dates" ]]
265+ self ._load_individual_files (address , nodes , data )
322266 arcs = data .get ("arcs" , {})
323267 self .add_nodes (list (nodes .values ()))
324268 self .add_arcs (list (arcs .values ()))
@@ -327,6 +271,80 @@ def load(self, address, config_name="config.yml", overrides={}):
327271
328272 apply_patches (self )
329273
def _load_unified_data(self, address, unified_data_file, nodes, data):
    """Load model input data from a unified parquet file.

    The parquet file is read into ``self.unified_data`` and split into two
    lookups, each mapping to a ``{(variable, time): value}`` dictionary:
    rows with a ``surface`` entry are keyed by ``(node, surface)`` with
    ``time`` converted to a monthly period, and rows without one are keyed by
    ``node`` with ``time`` converted to a timestamp. Every node or surface
    in ``nodes`` whose ``data_input_dict`` entry is truthy has that entry
    replaced by the matching lookup. ``nodes`` is modified in place, and
    each node's ``surfaces`` dictionary is replaced by a list of its values.

    Args:
        address (str): Path to directory containing the unified data file
        unified_data_file (str): Name of the unified parquet file
        nodes (dict): Dictionary of node configurations
        data (dict): Full configuration dictionary

    Raises:
        KeyError: If a node or surface flagged with ``data_input_dict`` has
            no rows in the unified data.
    """

    def _lookup_by_group(frame, by):
        """Map every group key of ``frame`` to {(variable, time): value}."""
        # Iterating the groups (rather than ``groupby(...).apply``) avoids
        # the pandas deprecation of ``apply`` operating on the grouping
        # columns and gives an empty dict when there are no rows.
        return {
            group_key: group.set_index(["variable", "time"])["value"].to_dict()
            for group_key, group in frame.groupby(by)
        }

    # Load unified data
    unified_data_path = os.path.join(address, unified_data_file)
    self.unified_data = pd.read_parquet(unified_data_path)

    # Build each lookup once up front; this is much faster than filtering
    # the DataFrame separately for every node.
    has_surface = self.unified_data["surface"].notna()

    surface_data = self.unified_data.loc[has_surface].copy()
    surface_data["time"] = pd.to_datetime(surface_data["time"]).dt.to_period("M")
    surface_dict = _lookup_by_group(surface_data, ["node", "surface"])

    node_data = self.unified_data.loc[~has_surface].copy()
    node_data["time"] = pd.to_datetime(node_data["time"])
    node_dict = _lookup_by_group(node_data, "node")

    # Swap the truthy ``data_input_dict`` flags for the actual data
    for name, node in nodes.items():
        if node.get("data_input_dict"):
            node["data_input_dict"] = node_dict[name]
        if "surfaces" in node:
            for key, surface in node["surfaces"].items():
                if surface.get("data_input_dict"):
                    surface["data_input_dict"] = surface_dict[(name, key)]
            node["surfaces"] = list(node["surfaces"].values())

    if "dates" in data:
        self.dates = pd.to_datetime(data["dates"])
321+
def _load_individual_files(self, address, nodes, data):
    """Load model input data from one CSV file per node or surface.

    Any node or surface configuration that has a ``filename`` entry gets
    the contents of that file (read with ``read_csv``) stored under
    ``data_input_dict``, after which the ``filename`` entry is removed.
    ``nodes`` is modified in place, and each node's ``surfaces`` dictionary
    is replaced by a list of its values. If ``data`` contains ``dates``,
    they are converted with ``to_datetime`` and stored on ``self.dates``.

    Args:
        address (str): Path to directory containing the data files
        nodes (dict): Dictionary of node configurations
        data (dict): Full configuration dictionary
    """
    for node_config in nodes.values():
        if "filename" in node_config:
            node_path = os.path.join(address, node_config["filename"])
            node_config["data_input_dict"] = read_csv(node_path)
            # Only drop the filename once the file has been read
            del node_config["filename"]

        if "surfaces" not in node_config:
            continue

        for surface_config in node_config["surfaces"].values():
            if "filename" not in surface_config:
                continue
            surface_path = os.path.join(address, surface_config["filename"])
            surface_config["data_input_dict"] = read_csv(surface_path)
            del surface_config["filename"]

        # Surfaces are keyed by name in the config; keep only the values
        node_config["surfaces"] = list(node_config["surfaces"].values())

    if "dates" in data:
        self.dates = [to_datetime(date) for date in data["dates"]]
347+
330348 def save (self , address , config_name = "config.yml" , compress = False ):
331349 """Save the model object to a yaml file and input data to csv.gz format in the
332350 directory specified.
@@ -373,7 +391,9 @@ def _save_model_config(
373391 special_args = set (["surfaces" , "parent" , "data_input_dict" ])
374392
375393 node_props = {
376- x : getattr (node , x ) for x in set (init_args ).difference (special_args )
394+ x : getattr (node , x )
395+ for x in set (init_args ).difference (special_args )
396+ if hasattr (node , x )
377397 }
378398 node_props ["type_" ] = node .__class__ .__name__
379399 node_props ["node_type_override" ] = (
@@ -387,6 +407,7 @@ def _save_model_config(
387407 surface_props = {
388408 x : getattr (surface , x )
389409 for x in set (surface_args ).difference (special_args )
410+ if hasattr (surface , x )
390411 }
391412 surface_props ["type_" ] = surface .__class__ .__name__
392413
@@ -542,8 +563,11 @@ def save_unified_data(
542563 config_name (str): Name of the config file
543564 compress (bool): Whether to compress (not used for parquet)
544565 """
545- if not PANDAS_AVAILABLE :
546- raise ImportError ("pandas is required for unified data saving" )
566+ if not PARQUET_AVAILABLE :
567+ raise ImportError (
568+ "parquet support (pyarrow or fastparquet) is required "
569+ "for unified data saving"
570+ )
547571
548572 if not os .path .exists (address ):
549573 os .mkdir (address )
@@ -1278,8 +1302,11 @@ def create_unified_dataframe(nodes_data, surfaces_data=None):
12781302 pd.DataFrame: Unified DataFrame with columns: node, surface, variable,
12791303 time, value
12801304 """
1281- if not PANDAS_AVAILABLE :
1282- raise ImportError ("pandas is required for unified DataFrame support" )
1305+ if not PARQUET_AVAILABLE :
1306+ raise ImportError (
1307+ "parquet support (pyarrow or fastparquet) is required "
1308+ "for unified DataFrame support"
1309+ )
12831310
12841311 rows = []
12851312
0 commit comments