Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
fastprogress
fire
numpy
polars
polars>=1.36.1
pyarrow>=20
torch
torchaudio
Expand Down
2 changes: 1 addition & 1 deletion tests.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import unittest
import doctest
from . import ws_dataset, ws_shard, ws_sink
from wsds import ws_dataset, ws_shard, ws_sink

def load_tests(loader, tests, ignore):
tests.addTests(doctest.DocTestSuite(ws_dataset))
Expand Down
31 changes: 7 additions & 24 deletions wsds/ws_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,29 +138,17 @@ def __getitem__(self, key_or_index: str | int):
# Figure out the shard name, local offset (wrt shard) and global offset for the given key or index
shard_name, local_offset, global_offset = None, None, None

if self.index.has_dataset_path:
dataset_path = 's.dataset_path'
else:
dataset_path = "''"

if isinstance(key_or_index, int):
r = self.index.query(
f"SELECT s.shard, global_offset, {dataset_path} FROM shards AS s WHERE s.global_offset <= ? ORDER BY s.global_offset DESC LIMIT 1",
key_or_index,
).fetchone()
r = self.index.get_shard_by_global_index(key_or_index)
if not r:
return None

shard_name, shard_global_offset, dataset_path = r
global_offset = key_or_index
local_offset = global_offset - shard_global_offset
elif isinstance(key_or_index, str):
# FIXME: push `parse_key` to the index class
file_name, offset_of_key_wrt_file = self.parse_key(key_or_index)
r = self.index.query(
f"SELECT s.shard, s.global_offset, f.offset, {dataset_path} FROM files AS f, shards AS s WHERE f.name = ? AND s.shard_id == f.shard_id",
file_name,
).fetchone()
r = self.index.get_shard_by_file_name(file_name)
if not r:
return None

Expand All @@ -186,9 +174,7 @@ def sequential_from(self, sample, max_N=None):
shard_global_offset = None
if self._filter_dfs is not None:
# We need to know the global shard offset to know what filter values to use for the sample
shard_global_offset = self.index.query(
"SELECT global_offset FROM shards WHERE shard = ?", shard_name
).fetchone()[0]
shard_global_offset = self.index.get_shard_global_offset(shard_name)

while i < max_N:
sample = WSSample(self, shard_name, i)
Expand All @@ -209,10 +195,10 @@ def sequential_from(self, sample, max_N=None):
def _shard_n_samples(self, shard_name: (str, str)) -> int:
if not self.index:
return sys.maxsize
r = self.index.query("SELECT n_samples FROM shards WHERE shard = ?", shard_name[1]).fetchone()
if r is None:
n_samples = self.index.get_shard_n_samples(shard_name)
if n_samples is None:
raise IndexError(f"Shard not found: {shard_name}")
return r[0]
return n_samples

def iter_shard(self, shard_name):
dataset_path, shard_name = shard_name
Expand Down Expand Up @@ -267,10 +253,7 @@ def _parse_sql_queries_polars(self, *queries):
subdir_samples[subdir] = df.clear().collect()
else:
# create a fake dataframe with all NULL rows and matching schema
if self.index.has_dataset_path:
n_samples, = self.index.query("SELECT n_samples FROM shards WHERE shards.dataset_path = ? AND shards.shard = ?", *shard).fetchone()
else:
n_samples, = self.index.query("SELECT n_samples FROM shards WHERE shards.shard = ?", shard[1]).fetchone()
n_samples = self.index.get_shard_n_samples(shard)
df = pl.defer(
lambda subdir=subdir, n_samples=n_samples: subdir_samples[subdir].clear(n=n_samples),
schema=lambda subdir=subdir: subdir_samples[subdir].schema,
Expand Down
202 changes: 187 additions & 15 deletions wsds/ws_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,16 @@ def __exit__(self, exc_type, exc_value, traceback):


class WSIndex:
"""SQLite-based index for fast random access to samples in a wsds dataset.

The index stores:
- `shards` table: shard names, sample counts, and global offsets
- `files` table: source file names, their shard, offset within shard, and duration info
- `metadata` table: JSON-encoded dataset metadata (e.g., segmented flag, fields)

This enables O(1) lookups by global sample index or by file name.
"""

def __init__(self, fname: str):
self.fname = fname
if not Path(fname).exists():
Expand All @@ -102,31 +112,196 @@ def __init__(self, fname: str):
self.conn = sqlite3.connect(f"file:{fname}?immutable=1,ro=True", uri=True)
self.has_dataset_path = self.conn.execute("SELECT COUNT(*) FROM pragma_table_info('shards') WHERE name='dataset_path'").fetchone()[0]

#
# Aggregate properties
#

@functools.cached_property
def n_shards(self) -> int:
    """Total number of shards in the dataset (cached after first access)."""
    return self.conn.execute("SELECT COUNT(*) FROM shards;").fetchone()[0]

@functools.cached_property
def n_files(self) -> int:
    """Total number of source files in the dataset (cached after first access)."""
    return self.conn.execute("SELECT COUNT(*) FROM files;").fetchone()[0]

@functools.cached_property
def n_samples(self) -> int:
    """Total number of samples across all shards (cached after first access)."""
    return self.conn.execute("SELECT SUM(n_samples) FROM shards;").fetchone()[0]

@functools.cached_property
def audio_duration(self) -> float:
    """Total audio duration in seconds across all files (cached after first access)."""
    return self.conn.execute("SELECT SUM(audio_duration) FROM files;").fetchone()[0]

@functools.cached_property
def speech_duration(self) -> float:
    """Total speech duration in seconds (for segmented datasets; cached after first access)."""
    return self.conn.execute("SELECT SUM(speech_duration) FROM files;").fetchone()[0]

@functools.cached_property
def metadata(self) -> dict:
    """Dataset metadata dictionary, merged from every row of the metadata table.

    A database without a metadata table yields an empty dict; any other
    database error propagates to the caller.
    """
    merged: dict = {}
    try:
        rows = self.conn.execute("SELECT value FROM metadata;")
    except sqlite3.OperationalError as err:
        if err.args[0] == "no such table: metadata":
            return merged
        raise
    for (chunk,) in rows:
        merged.update(json.loads(chunk))
    return merged

#
# Shard iteration
#

def shards(self):
    """Iterate over all shards as (dataset_path, shard_name) tuples.

    Returns a cursor yielding tuples in insertion (rowid) order; when the
    index has no dataset_path column, '' is yielded in its place.
    """
    if self.has_dataset_path:
        path_expr = "dataset_path"
    else:
        path_expr = "''"
    sql = f"SELECT {path_expr}, shard FROM shards ORDER BY rowid;"
    return self.conn.execute(sql)

#
# Shard lookups
#

def get_shard_by_global_index(self, global_index: int) -> tuple[str, int, str] | None:
"""Find the shard containing a given global sample index.

Args:
global_index: The global sample index (0-based across the entire dataset).

Returns:
Tuple of (shard_name, shard_global_offset, dataset_path) or None if not found.
The local offset within the shard is: global_index - shard_global_offset.
"""
dataset_path = 's.dataset_path' if self.has_dataset_path else "''"
r = self.conn.execute(
f"SELECT s.shard, global_offset, {dataset_path} FROM shards AS s "
"WHERE s.global_offset <= ? ORDER BY s.global_offset DESC LIMIT 1",
(global_index,),
).fetchone()
return r

def get_shard_by_file_name(self, file_name: str) -> tuple[str, int, int, str] | None:
"""Find the shard containing a given source file.

Args:
file_name: The source file name (without segment suffix).

Returns:
Tuple of (shard_name, shard_global_offset, file_offset_in_shard, dataset_path)
or None if not found.
"""
dataset_path = 's.dataset_path' if self.has_dataset_path else "''"
r = self.conn.execute(
f"SELECT s.shard, s.global_offset, f.offset, {dataset_path} "
"FROM files AS f, shards AS s WHERE f.name = ? AND s.shard_id == f.shard_id",
(file_name,),
).fetchone()
return r

def get_shard_global_offset(self, shard_name: str) -> int | None:
"""Get the global sample offset for a shard.

Args:
shard_name: The shard name (without .wsds extension).

Returns:
The global offset (first sample index in this shard), or None if not found.
"""
r = self.conn.execute(
"SELECT global_offset FROM shards WHERE shard = ?",
(shard_name,),
).fetchone()
return r[0] if r else None

def get_shard_n_samples(self, shard: tuple[str, str]) -> int | None:
"""Get the number of samples in a shard.

Args:
shard: Tuple of (dataset_path, shard_name).

Returns:
The number of samples in the shard, or None if not found.
"""
dataset_path, shard_name = shard
if self.has_dataset_path and dataset_path:
r = self.conn.execute(
"SELECT n_samples FROM shards WHERE dataset_path = ? AND shard = ?",
(dataset_path, shard_name),
).fetchone()
else:
r = self.conn.execute(
"SELECT n_samples FROM shards WHERE shard = ?",
(shard_name,),
).fetchone()
return r[0] if r else None

def get_shard_info(self, shard: tuple[str, str]) -> tuple[int, int] | None:
"""Get shard_id and n_samples for a shard.

Args:
shard: Tuple of (dataset_path, shard_name).

Returns:
Tuple of (n_samples, shard_id) or None if not found.
"""
dataset_path, shard_name = shard
if self.has_dataset_path and dataset_path:
r = self.conn.execute(
"SELECT n_samples, shard_id FROM shards WHERE dataset_path = ? AND shard = ?",
(dataset_path, shard_name),
).fetchone()
else:
r = self.conn.execute(
"SELECT n_samples, shard_id FROM shards WHERE shard = ?",
(shard_name,),
).fetchone()
return r

#
# File lookups
#

def get_files_for_shard(self, shard_id: int) -> list[tuple[str, int]]:
    """List every file stored in one shard.

    Args:
        shard_id: The internal shard ID.

    Returns:
        List of (file_name, offset) tuples.
    """
    cursor = self.conn.execute(
        "SELECT name, offset FROM files WHERE shard_id = ?", (shard_id,)
    )
    return cursor.fetchall()

def iter_files(self):
"""Iterate over all files with their shard and offset info.

Yields tuples of (file_name, shard_name, offset) ordered by file name.
"""
return self.conn.execute(
"SELECT name, s.shard, offset FROM files AS f, shards AS s "
"WHERE s.shard_id == f.shard_id ORDER BY name, s.shard, offset;"
)

#
# DataFrame export
#

def dataframe(self):
"""Export the index as a Polars DataFrame.

Returns:
DataFrame with columns: name, audio_duration, speech_duration, shard, n_samples.
"""
import polars as pl
df = pl.read_database_uri("""
SELECT f.name, audio_duration, speech_duration, s.shard, s.n_samples
Expand All @@ -135,18 +310,15 @@ def dataframe(self):
)
return df

@functools.cached_property
def metadata(self):
    """Dataset metadata dictionary, merged from every row of the metadata table.

    A database without a metadata table yields an empty dict; any other
    database error propagates to the caller.
    """
    merged = {}
    try:
        rows = self.conn.execute("SELECT value FROM metadata;")
    except sqlite3.OperationalError as err:
        if err.args[0] == "no such table: metadata":
            return merged
        raise
    for (chunk,) in rows:
        merged.update(json.loads(chunk))
    return merged
#
# Low-level query access (prefer using specific methods above)
#

def query(self, query, *args):
    """Run a raw SQL statement against the index database.

    Positional args are bound as SQL parameters. Prefer the purpose-specific
    lookup helpers when one fits; this is the low-level escape hatch.
    """
    cursor = self.conn.execute(query, args)
    return cursor

def __repr__(self):
Expand Down
4 changes: 1 addition & 3 deletions wsds/ws_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,7 @@ def dump_index(source_dataset: Path):
ds = WSDataset(source_dataset)

try:
for sample in ds.index.query(
"SELECT name,s.shard,offset FROM files AS f, shards AS s WHERE s.shard_id == f.shard_id ORDER BY name,s.shard,offset;"
):
for sample in ds.index.iter_files():
print(*sample)
except BrokenPipeError:
pass
Expand Down