@@ -56,6 +56,11 @@ class WSDataset:
     def __init__(self, dataset_dir: str | Path, include_in_progress: bool = True, key_folder: str | None = None, disable_memory_map: bool = False):
         self.dataset_dir = self._resolve_path(dataset_dir)
 
+        if include_in_progress is not True:
+            print("NOTE: include_in_progress is deprecated and all subdirs are included by default")
+        if key_folder is not None:
+            print("NOTE: key_folder is deprecated and key folder is selected automatically")
+
         self.index = None
         self.segmented = False
         self.disable_memory_map = disable_memory_map
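The hunk above keeps the deprecated keyword arguments in the __init__ signature so existing call sites keep working; within this diff they appear to do nothing beyond printing a NOTE. A minimal usage sketch under that assumption (the dataset path is hypothetical):

    # New-style call: subdirs and the key folder are discovered automatically.
    ds = WSDataset("/data/my_dataset")

    # Legacy call: still constructs the dataset, but prints the deprecation NOTEs
    # and otherwise ignores the two deprecated arguments.
    ds_legacy = WSDataset("/data/my_dataset", include_in_progress=False, key_folder="keys")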
@@ -71,12 +76,8 @@ def __init__(self, dataset_dir: str | Path, include_in_progress: bool = True, ke
             self.fields = meta['fields']
         else:
             dataset_path, shard_name = next(self.index.shards()) if self.index else ("", None)
-            self.fields = list_all_columns(
-                self.dataset_dir / dataset_path, shard_name, include_in_progress=include_in_progress
-            )
-            self.fields.update(list_all_columns(
-                self.dataset_dir, include_in_progress=include_in_progress, key_folder=key_folder
-            ))
+            self.fields = list_all_columns(self.dataset_dir / dataset_path, shard_name)
+            self.fields.update(list_all_columns(self.dataset_dir))
         if 'computed_columns' in meta:
             self.computed_columns = meta['computed_columns']
         else:
@@ -254,14 +255,19 @@ def _parse_sql_queries_polars(self, *queries, shard_subsample=1, rng=None, shard
                     # __key__ exists in all shards
                     needed_special_columns.append(col)
                     continue
-                subdir, field = self.fields[col]
+                value = self.fields[col]
+                # FIXME: figure out a way to handle all candidates for __key__
+                if isinstance(value[0], str):
+                    subdir, field = value
+                else:
+                    subdir, field = value[0]
                 assert col == field, "renamed fields are not supported in SQL queries yet"
                 subdirs[subdir].append(field)
             exprs.append(expr)
 
         # If only __key__ is in the query, we need to load shards from at least one subdir
         key_value = self.fields["__key__"]
-        key_subdir = key_value[0]
+        key_subdir = key_value[0] if isinstance(key_value[0], str) else key_value[0][0]
         if needed_special_columns:
             if subdirs:
                 key_subdir = list(subdirs.keys())[0]
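For context on the isinstance checks above: after this change an entry in self.fields can apparently take one of two shapes, either a single (subdir, field) pair or a list of such pairs when a column (notably __key__) is offered by several subdirs. A small self-contained sketch of the normalization pattern the diff repeats, with the dictionary contents invented purely for illustration:

    # Hypothetical fields mapping, for illustration only.
    fields = {
        "caption": ("text", "caption"),                           # single (subdir, field) pair
        "__key__": [("text", "__key__"), ("image", "__key__")],   # list of alternatives
    }

    def first_subdir(value):
        # str first element -> value is a single (subdir, field) pair
        # otherwise         -> value is a list of pairs; take the first alternative
        return value[0] if isinstance(value[0], str) else value[0][0]

    assert first_subdir(fields["caption"]) == "text"
    assert first_subdir(fields["__key__"]) == "text"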
@@ -439,7 +445,8 @@ def get_shard_path(self, subdir, shard_name):
         return (Path(dir) / shard_name).with_suffix(".wsds")
 
     def _register_wsds_links(self):
-        for subdir, _ in self.fields.values():
+        for value in self.fields.values():
+            subdir = value[0] if isinstance(value[0], str) else value[0][0]
             if subdir.endswith(".wsds-link"):
                 spec = json.loads((self.dataset_dir / subdir).read_text())
                 self.computed_columns[subdir] = spec
@@ -481,8 +488,16 @@ def get_shard(self, subdir, shard_name):
         return shard
 
     def get_sample(self, shard_name, field, offset):
-        subdir, column = self.fields[field]
-        return self.get_shard(subdir, shard_name).get_sample(column, offset)
+        value = self.fields[field]
+        alternatives = [value] if isinstance(value[0], str) else value
+        last_err = None
+        for subdir, column in alternatives:
+            try:
+                return self.get_shard(subdir, shard_name).get_sample(column, offset)
+            except WSShardMissingError as e:
+                last_err = e
+                continue
+        raise last_err
 
     def parse_key(self, key):
         if self.segmented:
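The reworked get_sample above normalizes the field entry into a list of (subdir, column) alternatives, tries each one in order, and re-raises the last WSShardMissingError only if every alternative fails. A standalone sketch of that fallback behaviour; all class and function names below are stand-ins rather than the library's real API:

    class WSShardMissingError(Exception):
        pass  # stand-in for the real exception raised when a shard file is absent

    class FakeShard:
        def __init__(self, data):
            self.data = data

        def get_sample(self, column, offset):
            return self.data[column][offset]

    # Only the "image" subdir has a shard; "text" is missing.
    shards = {"image": FakeShard({"__key__": ["a", "b"]})}

    def get_shard(subdir, shard_name):
        if subdir not in shards:
            raise WSShardMissingError(f"{subdir}/{shard_name}")
        return shards[subdir]

    def get_sample(alternatives, shard_name, offset):
        last_err = None
        for subdir, column in alternatives:
            try:
                return get_shard(subdir, shard_name).get_sample(column, offset)
            except WSShardMissingError as e:
                last_err = e
        raise last_err

    # "text" is tried first, is missing, and the lookup falls back to "image".
    assert get_sample([("text", "__key__"), ("image", "__key__")], "shard-000", 1) == "b"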