[lhotse] Added support for re-weighting datasets with temperature on the fly. #15200
base: main
Changes from 11 commits
```diff
@@ -87,6 +87,7 @@ model:
   shuffle: true
   num_workers: 6
   pin_memory: true
+  reweight_temperature: [1.0, 0.0] # Temperature for re-weighting datasets. 1 is a neutral value. Lower temperature over-samples smaller datasets, and vice versa.
```
Collaborator: Where is the documentation for what this parameter does? And can we put it in the same place as the other arguments?
Collaborator (Author): Good point. The only place where this new param seems to fit is here: https://github.com/NVIDIA-NeMo/NeMo/blob/main/docs/source/audio/configs.rst#lhotse-cutset. I will add a description there.
Collaborator (Author): I intentionally designed this with a strict, but I agree that we need to validate that the number of nested groups matches the list length. I'll look into how to enforce that check.
Collaborator (Author): Addressed all above comments.
Collaborator: I think you also need to add this to
Collaborator (Author): Good point. Indeed, my test using Fixed now.
Collaborator: This is my understanding, correct me if I am wrong: Level 1 (language) preserves the weight; Level 2 (datasets within each language) equalizes.
Collaborator (Author): You're right, but this is an example demonstrating how to use the new param. In practice, you should override it based on your needs.
```diff
 input_cfg:
   - type: lhotse_shar
```
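To make the two-level semantics discussed in the thread above concrete, here is a hypothetical nested config. The dataset paths, weights, and tags are illustrative, not from the PR; the field names follow the usual shape of a NeMo lhotse `input_cfg`, but should be checked against the docs before use:

```yaml
# Level 1 (language groups): temperature 1.0 keeps the weight ratios as given.
# Level 2 (datasets inside each group): temperature 0.0 equalizes them.
reweight_temperature: [1.0, 0.0]
input_cfg:
  - type: group
    weight: 0.7           # hypothetical language-level weight
    tags: {lang: en}
    input_cfg:
      - type: lhotse_shar
        shar_path: /data/en/corpus_a   # hypothetical path
        weight: 197                    # e.g., hours of audio
      - type: lhotse_shar
        shar_path: /data/en/corpus_b   # hypothetical path
        weight: 2159
```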
```diff
@@ -20,7 +20,7 @@
 from functools import partial
 from itertools import repeat
 from pathlib import Path
-from typing import KeysView, Mapping, Sequence, Tuple, Union
+from typing import KeysView, List, Mapping, Sequence, Tuple, Union

 import numpy as np
 import omegaconf
```
```diff
@@ -50,6 +50,39 @@
 from nemo.collections.common.parts.preprocessing.manifest import get_full_path


+def temperature_reweighting(weights: List[Union[float, int]], temperature: float = 1.0) -> List[float]:
+    """
+    Apply temperature scaling to dataset weights and normalize.
+
+    Formula: normalized_weight_i = (w_i ^ temperature) / sum(w_j ^ temperature)
+
+    Args:
+        weights: List of dataset weights (can be hours, sample counts, or probabilities).
+            Values can be any positive float/int, not limited to [0, 1].
+        temperature: Scaling factor.
+            - 1.0: preserves original weight ratios
+            - 0.0: equalizes all weights (w^0 = 1)
+            - <1.0: oversamples smaller datasets
+            - >1.0: amplifies weight differences
+
+    Returns:
+        Normalized weights that sum to 1.0
+
+    Example:
+        >>> temperature_reweighting([197, 2159], temperature=1.0)  # hours
+        [0.0836, 0.9164]  # preserves ratio
+        >>> temperature_reweighting([197, 2159], temperature=0.0)  # equalize
+        [0.5, 0.5]
+    """
+    if len(weights) == 0:
+        return []
+    weights = np.asarray(weights)
+    if np.any(weights <= 0):
+        raise ValueError(f"All weights must be positive (> 0), got: {weights.tolist()}")
+    weights = weights**temperature
+    return (weights / weights.sum()).tolist()
```
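The scaling formula can be exercised standalone. This is a minimal sketch re-implementing the reweighting from the diff (NumPy only; the function name mirrors the patch, but this is not the NeMo code itself):

```python
import numpy as np


def temperature_reweighting(weights, temperature=1.0):
    """Scale each weight by w_i**temperature, then renormalize to sum to 1."""
    if len(weights) == 0:
        return []
    w = np.asarray(weights, dtype=float)
    if np.any(w <= 0):
        raise ValueError(f"All weights must be positive (> 0), got: {w.tolist()}")
    w = w**temperature
    return (w / w.sum()).tolist()


# T=1 preserves the ratio of dataset sizes (here, hours); T=0 equalizes.
print(temperature_reweighting([197, 2159], temperature=1.0))
print(temperature_reweighting([197, 2159], temperature=0.0))  # [0.5, 0.5]
# 0 < T < 1 boosts the smaller dataset relative to T=1 sampling.
print(temperature_reweighting([197, 2159], temperature=0.5))
```

Note that for `temperature=0.0` every positive weight maps to `1.0`, so any number of datasets comes out uniformly weighted.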
```diff
 def read_cutset_from_config(config: Union[DictConfig, dict]) -> Tuple[CutSet, bool]:
     """
     Reads NeMo configuration and creates a CutSet either from Lhotse or NeMo manifests.
```
```diff
@@ -210,7 +243,51 @@ def read_dataset_config(config) -> tuple[CutSet, bool]:
         "force_map_dataset": config.get("force_map_dataset", False),
         "force_iterable_dataset": config.get("force_iterable_dataset", False),
         "slice_length": config.get("slice_length", None),
+        # Temperature for re-weighting datasets. 1 is a neutral value. Lower temperature over-samples smaller datasets, and vice versa.
+        "reweight_temperature": config.get("reweight_temperature", None),
     }

+    # Standardize reweight_temperature to match the nesting depth
+    if propagate_attrs["reweight_temperature"] is not None:
+        expected_length = count_input_cfg_levels(config)
+        reweight_temp = propagate_attrs["reweight_temperature"]
+
+        # Case 1: Scalar value - broadcast to all levels
+        if isinstance(reweight_temp, (int, float)):
+            propagate_attrs["reweight_temperature"] = [float(reweight_temp)] * expected_length
+            logging.warning(
+                f"reweight_temperature is a scalar ({reweight_temp}), broadcasting to all {expected_length} levels. "
+                f"Expanded to: {propagate_attrs['reweight_temperature']}"
+            )
+        else:
+            # Case 2: Convert to list if needed (e.g., from ListConfig)
+            reweight_temp = list(reweight_temp)
+            actual_length = len(reweight_temp)
+
+            if actual_length == expected_length:
+                # Case 2.1: Exact match - no modification needed
+                propagate_attrs["reweight_temperature"] = reweight_temp
+            elif actual_length < expected_length:
+                # Case 2.2: Too short - extend by repeating last value
+                last_value = reweight_temp[-1] if reweight_temp else 1.0
+                extended = reweight_temp + [last_value] * (expected_length - actual_length)
+                propagate_attrs["reweight_temperature"] = extended
+                logging.warning(
+                    f"reweight_temperature list is shorter than nesting depth: "
+                    f"got {actual_length} values for {expected_length} levels. "
+                    f"Extending by repeating last value ({last_value}). "
+                    f"Expanded to: {extended}"
+                )
+            else:
+                # Case 2.3: Too long - trim to max depth
+                trimmed = reweight_temp[:expected_length]
+                propagate_attrs["reweight_temperature"] = trimmed
+                logging.warning(
+                    f"reweight_temperature list is longer than nesting depth: "
+                    f"got {actual_length} values for {expected_length} levels. "
+                    f"Trimming extra values. Using: {trimmed}"
+                )

     cuts, is_tarred = parse_and_combine_datasets(config.input_cfg, propagate_attrs=propagate_attrs)
     return cuts, is_tarred
```
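The broadcast/extend/trim handling can be isolated from the config plumbing. A sketch under the assumption that the nesting depth has already been computed (the helper name `normalize_temperatures` is mine, not the patch's; the patch keys everything off `propagate_attrs` instead):

```python
import warnings


def normalize_temperatures(reweight_temp, expected_length):
    """Coerce a scalar or list of temperatures to exactly one value per nesting level."""
    if isinstance(reweight_temp, (int, float)):
        # Case 1: scalar - broadcast the same temperature to every level.
        return [float(reweight_temp)] * expected_length
    temps = list(reweight_temp)  # e.g., convert an OmegaConf ListConfig
    if len(temps) < expected_length:
        # Case 2.2: too short - repeat the last value (neutral 1.0 if empty).
        last = temps[-1] if temps else 1.0
        warnings.warn("reweight_temperature shorter than nesting depth; repeating last value")
        return temps + [last] * (expected_length - len(temps))
    if len(temps) > expected_length:
        # Case 2.3: too long - extra values are dropped.
        warnings.warn("reweight_temperature longer than nesting depth; trimming")
    # Case 2.1 (exact match) falls through unchanged; slicing is a no-op then.
    return temps[:expected_length]


print(normalize_temperatures(0.5, 3))        # scalar broadcast
print(normalize_temperatures([1.0, 0.0], 2)) # exact match
print(normalize_temperatures([1.0], 3))      # extended by repetition
```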
```diff
@@ -342,6 +419,50 @@ def attach_tags(cut, tags: dict):
     return cut


+def count_input_cfg_levels(config: Union[DictConfig, dict]) -> int:
+    """
+    Compute the maximum nesting depth of 'input_cfg' keys in the configuration.
+
+    Each 'input_cfg' represents one level of nesting that consumes one temperature
+    value from reweight_temperature. Since sibling groups at the same level share
+    the same temperature (due to propagate_attrs.copy()), we count max depth,
+    not total occurrences.
+
+    Args:
+        config: Configuration dictionary that may contain nested 'input_cfg' keys.
+
+    Returns:
+        Maximum nesting depth of 'input_cfg' keys.
+
+    Example:
+        >>> config = {
+        ...     "input_cfg": [
+        ...         {"type": "group", "input_cfg": [{"type": "nemo"}]},
+        ...         {"type": "group", "input_cfg": [{"type": "nemo"}]},
+        ...     ]
+        ... }
+        >>> count_input_cfg_levels(config)
+        2
+    """
+
+    def _max_depth(obj) -> int:
+        if isinstance(obj, (dict, DictConfig)):
+            depths = []
+            for key, val in obj.items():
+                if key == "input_cfg":
+                    # Found input_cfg: this level counts as 1 + max depth of children
+                    depths.append(1 + _max_depth(val))
+                else:
+                    depths.append(_max_depth(val))
+            return max(depths, default=0)
+        elif isinstance(obj, (list, ListConfig)):
+            # For lists, find the max depth across all items (siblings)
+            return max((_max_depth(item) for item in obj), default=0)
+        return 0
+
+    return _max_depth(config)


 @data_type_parser("group")
 def parse_and_combine_datasets(
     config_list: Union[list[DictConfig], ListConfig], propagate_attrs: dict
```
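The depth count is a plain recursive max-depth walk. It can be sketched for plain dicts and lists (dropping the OmegaConf `DictConfig`/`ListConfig` cases of the patch) to check the semantics that siblings share a level:

```python
def count_input_cfg_levels(config) -> int:
    """Max nesting depth of 'input_cfg' keys; sibling groups share one level."""

    def _max_depth(obj) -> int:
        if isinstance(obj, dict):
            # An 'input_cfg' key adds one level on top of its children's depth.
            depths = [
                (1 if key == "input_cfg" else 0) + _max_depth(val)
                for key, val in obj.items()
            ]
            return max(depths, default=0)
        if isinstance(obj, list):
            # Siblings at the same level do not stack: take the max, not the sum.
            return max((_max_depth(item) for item in obj), default=0)
        return 0

    return _max_depth(config)


cfg = {
    "input_cfg": [
        {"type": "group", "input_cfg": [{"type": "nemo"}]},
        {"type": "group", "input_cfg": [{"type": "nemo"}]},
    ]
}
print(count_input_cfg_levels(cfg))  # 2
```

Two sibling groups still yield a depth of 2, not 3, which is why a `reweight_temperature` list of length 2 suffices for that config.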
```diff
@@ -351,6 +472,14 @@ def parse_and_combine_datasets(
     weights = []
     tarred_status = []

+    # Extract the temperature for re-weighting datasets.
+    if not propagate_attrs["reweight_temperature"]:
+        temperature = 1.0
+        next_temperatures = None
+    else:
+        temperature, *next_temperatures = propagate_attrs["reweight_temperature"]
+        propagate_attrs["reweight_temperature"] = next_temperatures
```
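Each recursive call into a nested group peels one temperature off the front of the list, so nesting level N consumes element N. A tiny sketch of that head/tail split (the helper name `peel_temperature` is mine; the patch does this inline against `propagate_attrs`):

```python
def peel_temperature(reweight_temperature):
    """Return (temperature for this level, remainder to propagate to children)."""
    if not reweight_temperature:
        # Feature unused, or no temperatures left for deeper levels: neutral T=1.
        return 1.0, None
    temperature, *next_temperatures = reweight_temperature
    return temperature, next_temperatures


print(peel_temperature([1.0, 0.0]))  # (1.0, [0.0])
print(peel_temperature([0.0]))       # (0.0, [])
print(peel_temperature(None))        # (1.0, None)
```

Because the remainder is written back before recursing, children below the configured depth simply fall into the neutral `1.0` branch.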
Comment on lines +475 to +481

Collaborator: What happens if the length of the reweight_temperature list is not equal to the number of datasets in the config?

Collaborator (Author): The

Collaborator (Author): Similar question as @pzelasko, please check my answer above. If both of you vote for supporting a float value, meaning applying the same temperature to all levels of nested groups, I can make changes accordingly.

Collaborator (Author): Added support of scalar.
```diff
     if isinstance(config_list, (str, Path)):
         # Resolve local filepath /path/to/input_cfg.yaml or remote url s3://bucket/path/to/input_cfg.yaml into config contents if needed.
         config_list = OmegaConf.create(load_yaml(config_list))
```
```diff
@@ -394,9 +523,14 @@ def parse_and_combine_datasets(
         ), "Missing dataset weight. When weighting datasets, every dataset must have a specified weight."

     if len(cuts) > 1:
+        if not weights:
+            reweights = None
+        else:
+            reweights = temperature_reweighting(weights, temperature=temperature)
+
         cuts = mux(
             *cuts,
-            weights=weights if weights else None,
+            weights=reweights,
             max_open_streams=propagate_attrs["max_open_streams"],
             seed=propagate_attrs["shard_seed"],
             force_finite=propagate_attrs["force_finite"] or propagate_attrs["metadata_only"],
```
Collaborator: We should support 1 and 2. 3 and 4 should throw an error.