-
Notifications
You must be signed in to change notification settings - Fork 1
Fix dataloading #15
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Fix dataloading #15
Changes from all commits
0376b17
1ff0139
444bcda
741ed67
2cc1d92
add5a59
6110877
46fb136
5b963ea
d34c682
ec9d906
a087da6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| __pycache__ |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -45,8 +45,9 @@ class WSDataset: | |
| # FIXME: this should be overridable with metadata in index.sqlite3 | ||
| _audio_file_keys = ["flac", "mp3", "sox", "wav", "m4a", "ogg", "wma", "opus", "audio"] | ||
|
|
||
| def __init__(self, dataset_dir: str | Path, include_in_progress: bool = True, key_folder: str | None = None, disable_memory_map: bool = False): | ||
| def __init__(self, dataset_dir: str | Path, include_in_progress: bool = True, key_folder: str | None = None, disable_memory_map: bool = False, skip_links: bool = False): | ||
| self.dataset_dir = self._resolve_path(dataset_dir) | ||
| self.skip_links = skip_links | ||
|
|
||
| self.index = None | ||
| self.segmented = False | ||
|
|
@@ -79,7 +80,8 @@ def __init__(self, dataset_dir: str | Path, include_in_progress: bool = True, ke | |
| self._open_shards = {} | ||
| self._linked_datasets = {} | ||
|
|
||
| self._register_wsds_links() | ||
| if not skip_links: | ||
| self._register_wsds_links() | ||
|
|
||
| def enable_filter(self, filter_name: str, filter_df: pl.DataFrame): | ||
| """ | ||
|
|
@@ -94,11 +96,6 @@ def enable_filter(self, filter_name: str, filter_df: pl.DataFrame): | |
|
|
||
| self._filter_dfs[filter_name] = filter_df | ||
|
|
||
| rows_satisfying_filter = filter_df.sum().item() | ||
| print( | ||
| f"Filter enabled on dataset {repr(self)}. Rows satisfying the filter: {rows_satisfying_filter} / {len(filter_df)}" | ||
| ) | ||
|
|
||
| # | ||
| # Accessing samples randomly and sequentially | ||
| # | ||
|
|
@@ -187,7 +184,7 @@ def sequential_from(self, sample, max_N=None): | |
| if self._filter_dfs is not None: | ||
| # We need to know the global shard offset to know what filter values to use for the sample | ||
| shard_global_offset = self.index.query( | ||
| "SELECT global_offset FROM shards WHERE shard = ?", shard_name | ||
| "SELECT global_offset FROM shards WHERE shard = ?", shard_name[1] | ||
| ).fetchone()[0] | ||
|
|
||
| while i < max_N: | ||
|
|
@@ -247,7 +244,8 @@ def _parse_sql_queries_polars(self, *queries, shard_subsample=1, rng=None): | |
| continue | ||
| subdir, field = self.fields[col] | ||
| assert col == field, "renamed fields are not supported in SQL queries yet" | ||
| subdirs[subdir].append(field) | ||
| if field not in subdirs[subdir]: | ||
| subdirs[subdir].append(field) | ||
| exprs.append(expr) | ||
|
|
||
| # If only __key__ is in the query, we need to load shards from at least one subdir | ||
|
|
@@ -265,12 +263,23 @@ def _parse_sql_queries_polars(self, *queries, shard_subsample=1, rng=None): | |
| if shard_subsample != 1: | ||
| shard_list = rng.sample(shard_list, int(len(shard_list) * shard_subsample)) | ||
|
|
||
| # TODO: Not sure if we want to drop the columns. I think previously we | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We could apply the renaming on SQL as well, which would make it consistent with the non-SQL API. I think most of the issues with duplicate fields were fixed recently when we added the select in this line: Do you remember which duplicate columns are causing you headaches?
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. It was
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Oh, interesting. We need to fix this then. |
||
| # did some renaming, but that might also be confusing. Maybe we can put | ||
| # this behind a flag. | ||
| cols_seen: set[str] = set() | ||
| deduplicated_subdirs: dict[str, list[str]] = {} | ||
| for subdir, fields in subdirs.items(): | ||
| unique_fields = [f for f in fields if f not in cols_seen] | ||
| if unique_fields: | ||
| deduplicated_subdirs[subdir] = unique_fields | ||
| cols_seen.update(unique_fields) | ||
|
|
||
| row_merge = [] | ||
| subdir_samples = {} | ||
| missing = defaultdict(list) | ||
| for shard in shard_list: | ||
| col_merge = [] | ||
| for subdir, fields in subdirs.items(): | ||
| for subdir, fields in deduplicated_subdirs.items(): | ||
| shard_path = self.get_shard_path(subdir, shard) | ||
| if shard_path.exists(): | ||
| df = scan_ipc(shard_path, glob=False).select(fields) | ||
|
|
@@ -300,7 +309,54 @@ def _parse_sql_queries_polars(self, *queries, shard_subsample=1, rng=None): | |
| f"No usable shards found (columns: {', '.join(subdirs)}) for dataset in: {str(self.dataset_dir)}" | ||
| ) | ||
|
|
||
| return exprs, pl.concat(row_merge).select(exprs) | ||
| def _common_dtype(col_name: str, a: pl.DataType, b: pl.DataType) -> pl.DataType: | ||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. My guess would be that this is about some shards having a null type because all the samples were None for a column? Would
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The issue is that some metrics have shards stored as both
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'll try |
||
| if a == b: | ||
| return a | ||
| if a == pl.Null: | ||
| return b | ||
| if b == pl.Null: | ||
| return a | ||
|
|
||
| if not (a.is_numeric() and b.is_numeric()): | ||
| raise TypeError(f"Cannot reconcile dtypes for column {col_name!r}: {a} vs {b}.") | ||
| # Make numeric dtypes consistent across shards so diagonal concat doesn't fail. | ||
| if a.is_float() or b.is_float(): | ||
| return pl.Float64 | ||
| if a.is_unsigned_integer() and b.is_unsigned_integer(): | ||
| return pl.UInt64 | ||
| if a.is_signed_integer() and b.is_signed_integer(): | ||
| return pl.Int64 | ||
|
|
||
| if a == pl.UInt64 or b == pl.UInt64: | ||
| raise TypeError(f"Cannot safely coerce column {col_name!r}: mixing UInt64 with signed integer ({a} vs {b})") | ||
|
|
||
| return pl.Int64 | ||
|
|
||
|
|
||
| def _cast_lazyframe_to_schema(lf: pl.LazyFrame, target_dtypes: dict[str, pl.DataType]) -> pl.LazyFrame: | ||
| schema = lf.collect_schema() | ||
| exprs = [] | ||
| for col_name, target_dtype in target_dtypes.items(): | ||
| if col_name not in schema: | ||
| continue | ||
| if schema[col_name] == target_dtype: | ||
| continue | ||
| exprs.append(pl.col(col_name).cast(target_dtype).alias(col_name)) | ||
| if not exprs: | ||
| return lf | ||
| return lf.with_columns(exprs) | ||
|
|
||
| target_dtypes: dict[str, pl.DataType] = {} | ||
| for lf in row_merge: | ||
| for col_name, dtype in lf.collect_schema().items(): | ||
| target_dtypes[col_name] = ( | ||
| _common_dtype(col_name, target_dtypes[col_name], dtype) | ||
| if col_name in target_dtypes else | ||
| dtype | ||
| ) | ||
| row_merge = [_cast_lazyframe_to_schema(lf, target_dtypes) for lf in row_merge] | ||
|
|
||
| return exprs, pl.concat(row_merge, how="diagonal").select(exprs) | ||
|
|
||
| def _check_for_subsampling(self, shard_subsample): | ||
| if shard_subsample is None: | ||
|
|
@@ -423,7 +479,13 @@ def get_linked_shard(self, link, shard_name): | |
| loader_class = link["loader"] | ||
| if isinstance(loader_class, list): | ||
| loader_mod, loader = loader_class | ||
| loader_module = importlib.import_module(loader_mod) | ||
|
|
||
| try: | ||
| loader_module = importlib.import_module(loader_mod) | ||
| except ImportError: | ||
| loader_mod = loader_mod.replace("hume_wsds", "wsds") | ||
| loader_module = importlib.import_module(loader_mod) | ||
|
|
||
| loader_class = getattr(loader_module, loader) | ||
|
|
||
| return loader_class.from_link( | ||
|
|
||
|
Member
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This unfortunately makes the test as expensive as loading the data with How do you use this downstream?
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I have some datasets where certain transcripts are only computed partially. I could also catch the exception in my code, but I felt it is a bit counterintuitive if
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Feels like in most cases you want to do something like
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. We could do the more thorough check if |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@rashishhume do you maybe know what's going on here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I deleted this since it caused a lot of log output at the start of my training jobs. Maybe it could make more sense to log this stuff higher up.