Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
__pycache__
2 changes: 2 additions & 0 deletions wsds/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

"""

from .utils import WSShardMissingError
from .ws_dataset import WSDataset
from .ws_sample import WSSample
from .ws_shard import WSSourceAudioShard
Expand All @@ -17,6 +18,7 @@
__all__ = [
WSDataset,
WSSample,
WSShardMissingError,
WSSourceAudioShard,
AtomicFile,
SampleFormatChanged,
Expand Down
86 changes: 74 additions & 12 deletions wsds/ws_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,9 @@ class WSDataset:
# FIXME: this should be overridable with metadata in index.sqlite3
_audio_file_keys = ["flac", "mp3", "sox", "wav", "m4a", "ogg", "wma", "opus", "audio"]

def __init__(self, dataset_dir: str | Path, include_in_progress: bool = True, key_folder: str | None = None, disable_memory_map: bool = False):
def __init__(self, dataset_dir: str | Path, include_in_progress: bool = True, key_folder: str | None = None, disable_memory_map: bool = False, skip_links: bool = False):
self.dataset_dir = self._resolve_path(dataset_dir)
self.skip_links = skip_links

self.index = None
self.segmented = False
Expand Down Expand Up @@ -79,7 +80,8 @@ def __init__(self, dataset_dir: str | Path, include_in_progress: bool = True, ke
self._open_shards = {}
self._linked_datasets = {}

self._register_wsds_links()
if not skip_links:
self._register_wsds_links()

def enable_filter(self, filter_name: str, filter_df: pl.DataFrame):
"""
Expand All @@ -94,11 +96,6 @@ def enable_filter(self, filter_name: str, filter_df: pl.DataFrame):

self._filter_dfs[filter_name] = filter_df

rows_satisfying_filter = filter_df.sum().item()
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rashishhume do you maybe know what's going on here?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I deleted this since it caused a lot of log output at the start of my training jobs. Maybe it would make more sense to log this higher up.

print(
f"Filter enabled on dataset {repr(self)}. Rows satisfying the filter: {rows_satisfying_filter} / {len(filter_df)}"
)

#
# Accessing samples randomly and sequentially
#
Expand Down Expand Up @@ -187,7 +184,7 @@ def sequential_from(self, sample, max_N=None):
if self._filter_dfs is not None:
# We need to know the global shard offset to know what filter values to use for the sample
shard_global_offset = self.index.query(
"SELECT global_offset FROM shards WHERE shard = ?", shard_name
"SELECT global_offset FROM shards WHERE shard = ?", shard_name[1]
).fetchone()[0]

while i < max_N:
Expand Down Expand Up @@ -247,7 +244,8 @@ def _parse_sql_queries_polars(self, *queries, shard_subsample=1, rng=None):
continue
subdir, field = self.fields[col]
assert col == field, "renamed fields are not supported in SQL queries yet"
subdirs[subdir].append(field)
if field not in subdirs[subdir]:
subdirs[subdir].append(field)
exprs.append(expr)

# If only __key__ is in the query, we need to load shards from at least one subdir
Expand All @@ -265,12 +263,23 @@ def _parse_sql_queries_polars(self, *queries, shard_subsample=1, rng=None):
if shard_subsample != 1:
shard_list = rng.sample(shard_list, int(len(shard_list) * shard_subsample))

# TODO: Not sure if we want to drop the columns. I think previously we
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could apply the renaming on SQL as well which would make it consistent with the non-SQL API.

I think most of the issues with duplicate fields were fixed recently when we added the select in this line:

                    df = scan_ipc(shard_path, glob=False).select(fields)

Do you remember which duplicate columns are causing you headaches?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was language_whisper.txt I think, since this is contained in all the shards with Whisper transcripts.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, interesting. We need to fix this then.

# did some renaming, but that might also be confusing. Maybe we can put
# this behind a flag.
cols_seen: set[str] = set()
deduplicated_subdirs: dict[str, list[str]] = {}
for subdir, fields in subdirs.items():
unique_fields = [f for f in fields if f not in cols_seen]
if unique_fields:
deduplicated_subdirs[subdir] = unique_fields
cols_seen.update(unique_fields)

row_merge = []
subdir_samples = {}
missing = defaultdict(list)
for shard in shard_list:
col_merge = []
for subdir, fields in subdirs.items():
for subdir, fields in deduplicated_subdirs.items():
shard_path = self.get_shard_path(subdir, shard)
if shard_path.exists():
df = scan_ipc(shard_path, glob=False).select(fields)
Expand Down Expand Up @@ -300,7 +309,54 @@ def _parse_sql_queries_polars(self, *queries, shard_subsample=1, rng=None):
f"No usable shards found (columns: {', '.join(subdirs)}) for dataset in: {str(self.dataset_dir)}"
)

return exprs, pl.concat(row_merge).select(exprs)
def _common_dtype(col_name: str, a: pl.DataType, b: pl.DataType) -> pl.DataType:
    """Pick a dtype that both shard dtypes for *col_name* can be safely cast to.

    Shards written at different times may store the same metric with
    different numeric widths (e.g. float16 vs float64), or as Null when
    every value in a shard was None; reconcile those here so the diagonal
    concat does not fail on mismatched schemas.

    Raises:
        TypeError: if the two dtypes differ and are not both numeric, or if
            UInt64 is mixed with a signed integer (no lossless common type).
    """
    if a == b:
        return a
    # An all-None shard column is stored as Null; defer to the other side.
    if a == pl.Null:
        return b
    if b == pl.Null:
        return a

    if not (a.is_numeric() and b.is_numeric()):
        raise TypeError(f"Cannot reconcile dtypes for column {col_name!r}: {a} vs {b}.")
    # Widen to the largest dtype of the shared numeric kind so the cast is safe.
    if a.is_float() or b.is_float():
        return pl.Float64
    if a.is_unsigned_integer() and b.is_unsigned_integer():
        return pl.UInt64
    if a.is_signed_integer() and b.is_signed_integer():
        return pl.Int64

    # Mixed signedness: UInt64 values can exceed Int64's range, so refuse.
    if a == pl.UInt64 or b == pl.UInt64:
        raise TypeError(f"Cannot safely coerce column {col_name!r}: mixing UInt64 with signed integer ({a} vs {b})")

    # Any smaller unsigned type fits losslessly into Int64.
    return pl.Int64


def _cast_lazyframe_to_schema(lf: pl.LazyFrame, target_dtypes: dict[str, pl.DataType]) -> pl.LazyFrame:
    """Cast columns of *lf* whose dtype differs from the reconciled target schema.

    Columns named in *target_dtypes* but absent from *lf* are ignored, as are
    columns that already carry the target dtype. When nothing needs casting,
    *lf* is returned untouched.
    """
    current_schema = lf.collect_schema()
    casts = [
        pl.col(name).cast(dtype).alias(name)
        for name, dtype in target_dtypes.items()
        if name in current_schema and current_schema[name] != dtype
    ]
    return lf.with_columns(casts) if casts else lf

target_dtypes: dict[str, pl.DataType] = {}
for lf in row_merge:
for col_name, dtype in lf.collect_schema().items():
target_dtypes[col_name] = (
_common_dtype(col_name, target_dtypes[col_name], dtype)
if col_name in target_dtypes else
dtype
)
row_merge = [_cast_lazyframe_to_schema(lf, target_dtypes) for lf in row_merge]

return exprs, pl.concat(row_merge, how="diagonal").select(exprs)

def _check_for_subsampling(self, shard_subsample):
if shard_subsample is None:
Expand Down Expand Up @@ -423,7 +479,13 @@ def get_linked_shard(self, link, shard_name):
loader_class = link["loader"]
if isinstance(loader_class, list):
loader_mod, loader = loader_class
loader_module = importlib.import_module(loader_mod)

try:
loader_module = importlib.import_module(loader_mod)
except ImportError:
loader_mod = loader_mod.replace("hume_wsds", "wsds")
loader_module = importlib.import_module(loader_mod)

loader_class = getattr(loader_module, loader)

return loader_class.from_link(
Expand Down
11 changes: 10 additions & 1 deletion wsds/ws_sample.py
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This unfortunately makes the test as expensive as loading the data with .get('column').

How do you use this downstream?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have some datasets where certain transcripts are only computed partially. I could also catch the exception in my code, but I felt it is a bit counterintuitive if __contains__ succeeds but .get() fails.

Copy link
Copy Markdown
Author

@streichgeorg streichgeorg Jan 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Feels like in most cases you want to do something like

if "col" in sample:
    # Do something with sample["col"]
else:
    # Do something else

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could do the more thorough check if include_partial_shards is set on the dataset?

Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,16 @@ def get(self, field: str, default=None):
return self[field] if field in self else default

def __contains__(self, field):
    """Report whether *field* can actually be read from this sample.

    Overrides always count as present. Otherwise the field must be declared
    in the dataset schema AND its backing shard must exist — a field whose
    shard is missing (e.g. .in-progress directories) reports False, so that
    `field in sample` and `sample[field]` stay consistent.
    """
    if field in self.overrides:
        return True
    if field not in self.dataset.fields:
        return False
    # NOTE(review): this loads the value, which makes the containment test as
    # expensive as .get(); accepted trade-off so `in` never lies about `[]`.
    try:
        self[field]
    except WSShardMissingError:
        return False
    return True

def __repr_field__(self, field, repr=repr):
try:
Expand Down
2 changes: 2 additions & 0 deletions wsds/ws_shard.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ def get_sample(self, column: str, offset: int) -> typing.Any:
if self._data.schema.get_field_index(column) == -1:
raise KeyError(f"column {column} not found in shard {self.fname}")
data = self._data[column][j]
if not data.is_valid:
return None
try:
# FIXME: implement proper encoders and decoders
if column.endswith("npy"):
Expand Down
51 changes: 30 additions & 21 deletions wsds/ws_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,6 +395,7 @@ def init(
source_dataset: Path | None = None,
vad_column: str | None = None,
num_workers: int = 32,
key_folder: str | None = None,
):
"""Initialize a new dataset, from scratch or from a segmentation of an existing one."""
import multiprocessing
Expand All @@ -411,14 +412,16 @@ def init(
else:
source_dataset = new_dataset

ds = WSDataset(source_dataset)
shard_extractor = functools.partial(extract_index_for_shard, source_dataset, vad_column=vad_column)
ds = WSDataset(source_dataset, key_folder=key_folder)
shard_extractor = functools.partial(extract_index_for_shard, source_dataset, vad_column=vad_column, key_folder=key_folder)
all_shards = ds.get_shard_list(ignore_index = True)

with AtomicFile(new_dataset / "index.sqlite3") as fname:
with WSDSIndexWriter(fname) as index:
with multiprocessing.Pool(num_workers) as p:
for r in progress_bar(p.imap_unordered(shard_extractor, all_shards), total=len(all_shards)):
if r["n_samples"] == 0:
continue
r["dataset_path"] = ""
try:
index.append(r)
Expand Down Expand Up @@ -501,34 +504,40 @@ def init_split(
new_fields['audio'] = ("audio.wsds-computed", "audio")
index.append_metadata({"fields": new_fields})

def extract_index_for_shard(dataset, shard, vad_column=None):
def extract_index_for_shard(dataset, shard, vad_column=None, key_folder=None):
import pyarrow as pa

from . import WSDataset

ds = WSDataset(dataset)
ds = WSDataset(dataset, key_folder=key_folder)
index = []
i = 0

for s in ds.iter_shard(shard):
key = s["__key__"]
try:
for s in ds.iter_shard(shard):
key = s["__key__"]

if not vad_column:
n = 1
speech_duration = -1
else:
vad = s[vad_column]
n = len(vad)
speech_duration = 0
if vad.size > 0:
speech_duration = float((vad[:, -1] - vad[:, -2]).sum()) # tend - tstart
if not vad_column:
n = 1
speech_duration = -1
else:
vad = s[vad_column]
n = len(vad)
speech_duration = 0
if vad.size > 0:
speech_duration = float((vad[:, -1] - vad[:, -2]).sum()) # tend - tstart

audio_duration = s.get('load_duration') or s.get('est_duration') or -1

audio_duration = s['load_duration'] or s['est_duration'] or -1
if (
n > 0
): # in derived datasets, skip files with no vad segments (they won't have samples and will never appear as keys)
index.append((key, i, audio_duration, speech_duration))

if (
n > 0
): # in derived datasets, skip files with no vad segments (they won't have samples and will never appear as keys)
index.append((key, i, audio_duration, speech_duration))
i += n
except pa.lib.ArrowInvalid as e:
print(f"Skipping corrupt shard {shard[1]}: {e}")

i += n
return {
"shard_name": shard[1],
"index": index,
Expand Down