153 changes: 153 additions & 0 deletions docs/source/audio/configs.rst
@@ -128,6 +128,159 @@ An example train dataset in Lhotse shar format can be configured as follows:
A configuration file with Lhotse shar format can be found in the `SSL pretraining example configuration <https://github.com/NVIDIA/NeMo/blob/main/examples/audio/conf/flow_matching_generative_ssl_pretraining.yaml>`_.


Dataset Reweighting with Temperature
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

When combining multiple datasets using nested ``input_cfg`` groups, you can control the sampling distribution using the ``reweight_temperature`` parameter. This feature allows you to balance dataset sampling without manually recalculating weights **when adding or removing datasets**.

The temperature scaling formula is:

.. math::

\hat{w}_i = \frac{w_i^{\tau}}{\sum_{j} w_j^{\tau}}

where :math:`w_i` is the original weight of dataset :math:`i`, :math:`\tau` is the temperature, and :math:`\hat{w}_i` is the normalized sampling probability.

**How Temperature Works:**

- ``temperature = 1.0``: Preserves original weight ratios (neutral, no reweighting)
- ``temperature = 0.0``: Equalizes all datasets (each gets equal probability regardless of original weights)
- ``0 < temperature < 1.0``: Over-samples smaller datasets relative to larger ones
- ``temperature > 1.0``: Amplifies differences between dataset weights
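
These behaviors can be checked with a minimal standalone sketch of the scaling formula (illustrative only; the actual implementation is ``temperature_reweighting`` in ``nemo/collections/common/data/lhotse/cutset.py``):

.. code-block:: python

    import numpy as np

    def reweight(weights, temperature):
        # normalized_weight_i = (w_i ** tau) / sum_j (w_j ** tau)
        w = np.asarray(weights, dtype=float) ** temperature
        return (w / w.sum()).tolist()

    weights = [900, 100, 200]      # e.g. hours per dataset
    print(reweight(weights, 1.0))  # ~[0.75, 0.08, 0.17] -- original ratios preserved
    print(reweight(weights, 0.0))  # ~[0.33, 0.33, 0.33] -- equalized
    print(reweight(weights, 0.5))  # ~[0.55, 0.18, 0.26] -- smaller datasets over-sampled
    print(reweight(weights, 2.0))  # ~[0.94, 0.01, 0.05] -- differences amplified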

**Configuration Options:**

The ``reweight_temperature`` parameter supports flexible formats that are automatically standardized:

1. **Scalar value** (applied to all nesting levels, warning logged):

.. code-block:: yaml

train_ds:
use_lhotse: true
reweight_temperature: 0.5 # Applied to all levels, warning logged
input_cfg:
- type: group
input_cfg:
- type: lhotse_shar
shar_path: /path/to/dataset1
weight: 900
- type: lhotse_shar
shar_path: /path/to/dataset2
weight: 100
- type: lhotse_shar
shar_path: /path/to/dataset3
weight: 200
- type: nemo_tarred
manifest_filepath: /path/to/dataset4/manifest.json
tarred_audio_filepath: /path/to/dataset4/audio.tar
weight: 300

2. **List matching maximum nesting depth** (one temperature per level):

.. code-block:: yaml

train_ds:
use_lhotse: true
reweight_temperature: [1.0, 0.0] # Level 1: preserve ratios, Level 2: equalize
input_cfg:
- type: group
weight: 0.7
input_cfg:
- type: lhotse_shar
shar_path: /path/to/dataset1
weight: 600
- type: lhotse_shar
shar_path: /path/to/dataset2
weight: 400
- type: group
weight: 0.3
input_cfg:
- type: lhotse_shar
shar_path: /path/to/dataset3
weight: 100

3. **List shorter than maximum nesting depth** (extended by repeating last value, warning logged):

.. code-block:: yaml

train_ds:
use_lhotse: true
reweight_temperature: [1.0] # Extended to [1.0, 1.0] for 2 levels
input_cfg:
- type: group
input_cfg:
- type: lhotse_shar
shar_path: /path/to/dataset1

4. **List longer than maximum nesting depth** (trimmed to match, warning logged):

.. code-block:: yaml

train_ds:
use_lhotse: true
reweight_temperature: [1.0, 0.5, 0.0] # Trimmed to [1.0] for 1 level
input_cfg:
- type: lhotse_shar
shar_path: /path/to/dataset1
weight: 100
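
The four cases above can be condensed into a short sketch (a simplification of the logic in ``read_dataset_config`` in ``nemo/collections/common/data/lhotse/cutset.py``; the real implementation also logs a warning in every non-exact case):

.. code-block:: python

    def standardize_temperature(reweight_temp, expected_length):
        # Scalar -> broadcast to every nesting level.
        if isinstance(reweight_temp, (int, float)):
            return [float(reweight_temp)] * expected_length
        temps = list(reweight_temp)
        # Too short -> extend by repeating the last value.
        if len(temps) < expected_length:
            temps = temps + [temps[-1]] * (expected_length - len(temps))
        # Too long -> trim to the nesting depth.
        return temps[:expected_length]

    print(standardize_temperature(0.5, 2))              # [0.5, 0.5]
    print(standardize_temperature([1.0], 2))            # [1.0, 1.0]
    print(standardize_temperature([1.0, 0.5, 0.0], 1))  # [1.0]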

**Maximum Nesting Depth Calculation:**

The maximum nesting depth is calculated as the maximum depth of ``input_cfg`` keys in the configuration. Sibling groups at the same level share the same temperature value.

.. code-block:: yaml

# This has maximum nesting depth = 2
input_cfg: # Level 1
- type: group
input_cfg: # Level 2
- type: lhotse_shar
- type: group # Same level as above (sibling)
input_cfg: # Level 2 (same as above)
- type: lhotse_shar
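
The same depth can also be computed programmatically with ``count_input_cfg_levels``, added in ``nemo/collections/common/data/lhotse/cutset.py`` in this PR; a minimal usage sketch, assuming the function is imported directly from that module:

.. code-block:: python

    from nemo.collections.common.data.lhotse.cutset import count_input_cfg_levels

    config = {
        "input_cfg": [  # level 1
            {"type": "group", "input_cfg": [{"type": "lhotse_shar"}]},  # level 2
            {"type": "group", "input_cfg": [{"type": "lhotse_shar"}]},  # level 2 (sibling)
        ]
    }
    print(count_input_cfg_levels(config))  # 2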

**Example: Balancing Multiple Task Groups**

.. code-block:: yaml

train_ds:
use_lhotse: true
reweight_temperature: [1.0, 0.0] # Level 1: Preserve task ratios, Level 2: Equalize within tasks
input_cfg:
- type: group
weight: 0.7
tags:
task: asr
input_cfg:
- type: nemo_tarred
manifest_filepath: /path/to/asr1/manifest.json
tarred_audio_filepath: /path/to/asr1/audio.tar
weight: 600 # Large dataset
- type: nemo_tarred
manifest_filepath: /path/to/asr2/manifest.json
tarred_audio_filepath: /path/to/asr2/audio.tar
weight: 100 # Small dataset (will be upsampled with temp=0.0)
- type: group
weight: 0.3
tags:
task: ast
input_cfg:
- type: nemo_tarred
manifest_filepath: /path/to/ast1/manifest.json
tarred_audio_filepath: /path/to/ast1/audio.tar
weight: 50
- type: nemo_tarred
manifest_filepath: /path/to/ast2/manifest.json
tarred_audio_filepath: /path/to/ast2/audio.tar
weight: 200

In this example:

- Level 1 temperature is ``1.0``: The 70/30 split between ASR and AST groups is preserved
- Level 2 temperature is ``0.0``: Within each group, all datasets are sampled equally regardless of their original weights
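
Putting the two levels together, the effective sampling probabilities work out to roughly 0.7 × 0.5 = 0.35 for each ASR dataset and 0.3 × 0.5 = 0.15 for each AST dataset, independent of the per-dataset weights.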


Model Architecture Configuration
--------------------------------
Each configuration file should describe the model architecture being used for the experiment.
1 change: 1 addition & 0 deletions examples/tts/conf/magpietts/magpietts_lhotse.yaml
@@ -87,6 +87,7 @@ model:
shuffle: true
num_workers: 6
pin_memory: true
reweight_temperature: [1.0, 0.0] # Temperature for re-weighting datasets. 1 is a neutral value. Lower temperature over-samples smaller datasets, and vice versa.
Collaborator:

This is my understanding, correct me if I am wrong: Level 1 (Language): preserves the weight, Level 2 (datasets within each language): equalizes.
Based on this understanding, don't we want to normalize the weights based on number of shar files/size of the datasets at level 2?

Collaborator (author):

You're right, but this is just an example demonstrating how to use the new parameter. In practice, it should be overridden based on your needs.


input_cfg:
- type: lhotse_shar
138 changes: 136 additions & 2 deletions nemo/collections/common/data/lhotse/cutset.py
@@ -20,7 +20,7 @@
from functools import partial
from itertools import repeat
from pathlib import Path
from typing import KeysView, Mapping, Sequence, Tuple, Union
from typing import KeysView, List, Mapping, Sequence, Tuple, Union

import numpy as np
import omegaconf
@@ -50,6 +50,39 @@
from nemo.collections.common.parts.preprocessing.manifest import get_full_path


def temperature_reweighting(weights: List[Union[float, int]], temperature: float = 1.0) -> List[float]:
"""
Apply temperature scaling to dataset weights and normalize.

Formula: normalized_weight_i = (w_i ^ temperature) / sum(w_j ^ temperature)

Args:
weights: List of dataset weights (can be hours, sample counts, or probabilities).
Values can be any positive float/int, not limited to [0, 1].
temperature: Scaling factor.
- 1.0: preserves original weight ratios
- 0.0: equalizes all weights (w^0 = 1)
- <1.0: oversamples smaller datasets
- >1.0: amplifies weight differences

Returns:
Normalized weights that sum to 1.0

Example:
>>> temperature_reweighting([197, 2159], temperature=1.0) # hours
[0.0836, 0.9164] # preserves ratio
>>> temperature_reweighting([197, 2159], temperature=0.0) # equalize
[0.5, 0.5]
"""
if len(weights) == 0:
return []
weights = np.asarray(weights)
if np.any(weights <= 0):
raise ValueError(f"All weights must be positive (> 0), got: {weights.tolist()}")
weights = weights**temperature
return (weights / weights.sum()).tolist()


def read_cutset_from_config(config: Union[DictConfig, dict]) -> Tuple[CutSet, bool]:
"""
Reads NeMo configuration and creates a CutSet either from Lhotse or NeMo manifests.
@@ -210,7 +243,51 @@ def read_dataset_config(config) -> tuple[CutSet, bool]:
"force_map_dataset": config.get("force_map_dataset", False),
"force_iterable_dataset": config.get("force_iterable_dataset", False),
"slice_length": config.get("slice_length", None),
# Temperature for re-weighting datasets. 1 is a neutral value. Lower temperature over-samples smaller datasets, and vice versa.
"reweight_temperature": config.get("reweight_temperature", None),
}

# Standardize reweight_temperature to match the nesting depth
if propagate_attrs["reweight_temperature"] is not None:
expected_length = count_input_cfg_levels(config)
reweight_temp = propagate_attrs["reweight_temperature"]

# Case 1: Scalar value - broadcast to all levels
if isinstance(reweight_temp, (int, float)):
propagate_attrs["reweight_temperature"] = [float(reweight_temp)] * expected_length
logging.warning(
f"reweight_temperature is a scalar ({reweight_temp}), broadcasting to all {expected_length} levels. "
f"Expanded to: {propagate_attrs['reweight_temperature']}"
)
else:
# Case 2: Convert to list if needed (e.g., from ListConfig)
reweight_temp = list(reweight_temp)
actual_length = len(reweight_temp)

if actual_length == expected_length:
# Case 2.1: Exact match - no modification needed
propagate_attrs["reweight_temperature"] = reweight_temp
elif actual_length < expected_length:
# Case 2.2: Too short - extend by repeating last value
last_value = reweight_temp[-1] if reweight_temp else 1.0
extended = reweight_temp + [last_value] * (expected_length - actual_length)
propagate_attrs["reweight_temperature"] = extended
logging.warning(
f"reweight_temperature list is shorter than nesting depth: "
f"got {actual_length} values for {expected_length} levels. "
f"Extending by repeating last value ({last_value}). "
f"Expanded to: {extended}"
)
else:
# Case 2.3: Too long - trim to max depth
trimmed = reweight_temp[:expected_length]
propagate_attrs["reweight_temperature"] = trimmed
logging.warning(
f"reweight_temperature list is longer than nesting depth: "
f"got {actual_length} values for {expected_length} levels. "
f"Trimming extra values. Using: {trimmed}"
)

cuts, is_tarred = parse_and_combine_datasets(config.input_cfg, propagate_attrs=propagate_attrs)
return cuts, is_tarred

@@ -342,6 +419,50 @@ def attach_tags(cut, tags: dict):
return cut


def count_input_cfg_levels(config: Union[DictConfig, dict]) -> int:
"""
Compute the maximum nesting depth of 'input_cfg' keys in the configuration.

Each 'input_cfg' represents one level of nesting that consumes one temperature
value from reweight_temperature. Since sibling groups at the same level share
the same temperature (due to propagate_attrs.copy()), we count max depth,
not total occurrences.

Args:
config: Configuration dictionary that may contain nested 'input_cfg' keys.

Returns:
Maximum nesting depth of 'input_cfg' keys.

Example:
>>> config = {
... "input_cfg": [
... {"type": "group", "input_cfg": [{"type": "nemo"}]},
... {"type": "group", "input_cfg": [{"type": "nemo"}]},
... ]
... }
>>> count_input_cfg_levels(config)
2
"""

def _max_depth(obj) -> int:
if isinstance(obj, (dict, DictConfig)):
depths = []
for key, val in obj.items():
if key == "input_cfg":
# Found input_cfg: this level counts as 1 + max depth of children
depths.append(1 + _max_depth(val))
else:
depths.append(_max_depth(val))
return max(depths, default=0)
elif isinstance(obj, (list, ListConfig)):
# For lists, find the max depth across all items (siblings)
return max((_max_depth(item) for item in obj), default=0)
return 0

return _max_depth(config)


@data_type_parser("group")
def parse_and_combine_datasets(
config_list: Union[list[DictConfig], ListConfig], propagate_attrs: dict
@@ -351,6 +472,14 @@ def parse_and_combine_datasets(
weights = []
tarred_status = []

# Extract the temperature for re-weighting datasets.
if not propagate_attrs["reweight_temperature"]:
temperature = 1.0
next_temperatures = None
else:
temperature, *next_temperatures = propagate_attrs["reweight_temperature"]
propagate_attrs["reweight_temperature"] = next_temperatures

if isinstance(config_list, (str, Path)):
# Resolve local filepath /path/to/input_cfg.yaml or remote url s3://bucket/path/to/input_cfg.yaml into config contents if needed.
config_list = OmegaConf.create(load_yaml(config_list))
@@ -394,9 +523,14 @@ def parse_and_combine_datasets(
), "Missing dataset weight. When weighting datasets, every dataset must have a specified weight."

if len(cuts) > 1:
if not weights:
reweights = None
else:
reweights = temperature_reweighting(weights, temperature=temperature)

cuts = mux(
*cuts,
weights=weights if weights else None,
weights=reweights,
max_open_streams=propagate_attrs["max_open_streams"],
seed=propagate_attrs["shard_seed"],
force_finite=propagate_attrs["force_finite"] or propagate_attrs["metadata_only"],
4 changes: 4 additions & 0 deletions nemo/collections/common/data/lhotse/dataloader.py
@@ -103,6 +103,10 @@ class LhotseDataLoadingConfig:
shard_seed: int | str = "trng"
max_open_streams: int | None = None
cuda_expandable_segments: bool = True
# Temperature for re-weighting datasets. 1 is a neutral value. Lower temperature over-samples smaller datasets, and vice versa.
# Can be a scalar (applied to all levels) or a list (one per nesting level).
# If list length doesn't match nesting depth, it will be extended or trimmed with a warning.
reweight_temperature: Any = None  # float | int | list[float] | None
# e. Multi-config related options.
# Setting multi_config=True will scan the config for keys with DictConfig values,
# create a separate sampler for each, and fuse the samplers according to sampler_fusion.