Skip to content

Commit b3cb2f9

Browse files
committed
Removed .in-progress (and key_folder)!
1 parent 3957226 commit b3cb2f9

File tree

4 files changed

+35
-26
lines changed

4 files changed

+35
-26
lines changed

wsds/utils.py

Lines changed: 7 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -38,7 +38,7 @@ def find_first_shard(path):
3838
return None
3939

4040

41-
def list_all_columns(ds_path, shard_name=None, include_in_progress=True, key_folder=None):
41+
def list_all_columns(ds_path, shard_name=None):
4242
"""Given a dataset path, return a list of all columns.
4343
4444
If you also give a shard name it greatly speeds it up
@@ -53,13 +53,12 @@ def list_all_columns(ds_path, shard_name=None, include_in_progress=True, key_fol
5353
continue
5454
if not p.is_dir():
5555
continue
56-
is_in_progress = p.suffix == ".in-progress"
57-
if is_in_progress and not include_in_progress:
58-
continue
59-
if shard_name is None or is_in_progress:
56+
if shard_name is None:
6057
fname = find_first_shard(p)
6158
else:
6259
fname = (p / shard_name).with_suffix(".wsds")
60+
if not fname.exists():
61+
fname = find_first_shard(p)
6362
if fname and fname.exists():
6463
try:
6564
columns = get_columns(fname)
@@ -68,10 +67,8 @@ def list_all_columns(ds_path, shard_name=None, include_in_progress=True, key_fol
6867
continue
6968
for col in columns:
7069
if col == "__key__":
71-
if not is_in_progress or key_folder == fname.parent.name:
72-
# We need a subdir that has all shards but we don't wanna list all of them (that's expensive)
73-
# so instead we rely on a subdir naming convention (the .in-progress suffix) and never use these
74-
key_col.append((fname.stat().st_size, p.name, col))
70+
# List all potential __key__ columns (they should be in each shard)
71+
key_col.append((fname.stat().st_size, p.name, col))
7572
continue
7673
# seems like we should fix this during the original conversion
7774
if col in cols or col in dupes:
@@ -84,10 +81,7 @@ def list_all_columns(ds_path, shard_name=None, include_in_progress=True, key_fol
8481
else:
8582
cols[col] = (p.name, col)
8683
# use the smallest shards for __key__ (should be the fastest)
87-
if key_folder is not None:
88-
cols["__key__"] = next(col for col in key_col if key_folder == col[1])[1:]
89-
elif len(key_col) > 0:
90-
cols["__key__"] = sorted(key_col)[0][1:]
84+
cols["__key__"] = [x[1:] for x in sorted(key_col)]
9185
return dict(sorted(cols.items()))
9286

9387

wsds/ws_dataset.py

Lines changed: 26 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -56,6 +56,11 @@ class WSDataset:
5656
def __init__(self, dataset_dir: str | Path, include_in_progress: bool = True, key_folder: str | None = None, disable_memory_map: bool = False):
5757
self.dataset_dir = self._resolve_path(dataset_dir)
5858

59+
if include_in_progress is not True:
60+
print("NOTE: include_in_progress is deprecated and all subdirs are included by default")
61+
if key_folder is not None:
62+
print("NOTE: key_folder is deprecated and key folder is selected automatically")
63+
5964
self.index = None
6065
self.segmented = False
6166
self.disable_memory_map = disable_memory_map
@@ -71,12 +76,8 @@ def __init__(self, dataset_dir: str | Path, include_in_progress: bool = True, ke
7176
self.fields = meta['fields']
7277
else:
7378
dataset_path, shard_name = next(self.index.shards()) if self.index else ("", None)
74-
self.fields = list_all_columns(
75-
self.dataset_dir / dataset_path, shard_name, include_in_progress=include_in_progress
76-
)
77-
self.fields.update(list_all_columns(
78-
self.dataset_dir, include_in_progress=include_in_progress, key_folder=key_folder
79-
))
79+
self.fields = list_all_columns(self.dataset_dir / dataset_path, shard_name)
80+
self.fields.update(list_all_columns(self.dataset_dir))
8081
if 'computed_columns' in meta:
8182
self.computed_columns = meta['computed_columns']
8283
else:
@@ -254,14 +255,19 @@ def _parse_sql_queries_polars(self, *queries, shard_subsample=1, rng=None, shard
254255
# __key__ exists in all shards
255256
needed_special_columns.append(col)
256257
continue
257-
subdir, field = self.fields[col]
258+
value = self.fields[col]
259+
# FIXME: figure out a way to handle all candidates for __key__
260+
if isinstance(value[0], str):
261+
subdir, field = value
262+
else:
263+
subdir, field = value[0]
258264
assert col == field, "renamed fields are not supported in SQL queries yet"
259265
subdirs[subdir].append(field)
260266
exprs.append(expr)
261267

262268
# If only __key__ is in the query, we need to load shards from at least one subdir
263269
key_value = self.fields["__key__"]
264-
key_subdir = key_value[0]
270+
key_subdir = key_value[0] if isinstance(key_value[0], str) else key_value[0][0]
265271
if needed_special_columns:
266272
if subdirs:
267273
key_subdir = list(subdirs.keys())[0]
@@ -439,7 +445,8 @@ def get_shard_path(self, subdir, shard_name):
439445
return (Path(dir) / shard_name).with_suffix(".wsds")
440446

441447
def _register_wsds_links(self):
442-
for subdir, _ in self.fields.values():
448+
for value in self.fields.values():
449+
subdir = value[0] if isinstance(value[0], str) else value[0][0]
443450
if subdir.endswith(".wsds-link"):
444451
spec = json.loads((self.dataset_dir / subdir).read_text())
445452
self.computed_columns[subdir] = spec
@@ -481,8 +488,16 @@ def get_shard(self, subdir, shard_name):
481488
return shard
482489

483490
def get_sample(self, shard_name, field, offset):
484-
subdir, column = self.fields[field]
485-
return self.get_shard(subdir, shard_name).get_sample(column, offset)
491+
value = self.fields[field]
492+
alternatives = [value] if isinstance(value[0], str) else value
493+
last_err = None
494+
for subdir, column in alternatives:
495+
try:
496+
return self.get_shard(subdir, shard_name).get_sample(column, offset)
497+
except WSShardMissingError as e:
498+
last_err = e
499+
continue
500+
raise last_err
486501

487502
def parse_key(self, key):
488503
if self.segmented:

wsds/ws_sample.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -89,7 +89,8 @@ def __repr__(self, repr=repr):
8989
if k in self.overrides:
9090
subdir = "__overrides__"
9191
elif k in self.dataset.fields:
92-
subdir, _ = self.dataset.fields[k]
92+
value = self.dataset.fields[k]
93+
subdir = value[0] if isinstance(value[0], str) else value[0][0]
9394
else:
9495
subdir = "__unknown__"
9596
if subdir not in subdir_columns:

wsds/ws_tools.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -448,7 +448,6 @@ def init_split(
448448
source_dataset: Path | None = None,
449449
vad_column: str | None = None,
450450
num_workers: int = 64,
451-
include_in_progress: bool = False,
452451
index_path: str = ".",
453452
):
454453
"""Initialize a new dataset, from scratch or from a segmentation of an existing one."""

0 commit comments

Comments (0)