153 changes: 153 additions & 0 deletions docs/source/audio/configs.rst
@@ -128,6 +128,159 @@ An example train dataset in Lhotse shar format can be configured as follows:
A configuration file with Lhotse shar format can be found in the `SSL pretraining example configuration <https://github.com/NVIDIA/NeMo/blob/main/examples/audio/conf/flow_matching_generative_ssl_pretraining.yaml>`_.


Dataset Reweighting with Temperature
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

When combining multiple datasets using nested ``input_cfg`` groups, you can control the sampling distribution using the ``reweight_temperature`` parameter. This feature allows you to balance dataset sampling without manually recalculating weights **when adding or removing datasets**.

The temperature scaling formula is:

.. math::

    \hat{w}_i = \frac{w_i^{\tau}}{\sum_{j} w_j^{\tau}}

where :math:`w_i` is the original weight of dataset :math:`i`, :math:`\tau` is the temperature, and :math:`\hat{w}_i` is the normalized sampling probability.

**How Temperature Works:**

- ``temperature = 1.0``: Preserves original weight ratios (neutral, no reweighting)
- ``temperature = 0.0``: Equalizes all datasets (each gets equal probability regardless of original weights)
- ``0 < temperature < 1.0``: Over-samples smaller datasets relative to larger ones
- ``temperature > 1.0``: Amplifies differences between dataset weights
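
For a quick numeric check: two datasets weighted ``900`` and ``100`` are sampled with probabilities ``[0.9, 0.1]`` at ``temperature = 1.0``, ``[0.75, 0.25]`` at ``temperature = 0.5``, and ``[0.5, 0.5]`` at ``temperature = 0.0``. The sketch below is for illustration only and is not part of the configuration API; it mirrors the ``temperature_reweighting`` helper in ``nemo/collections/common/data/lhotse/cutset.py``:

.. code-block:: python

    import numpy as np

    def reweight(weights, temperature=1.0):
        # Raise each weight to the temperature and renormalize so the result sums to 1.
        w = np.asarray(weights, dtype=float) ** temperature
        return (w / w.sum()).tolist()

    reweight([900, 100], temperature=1.0)  # [0.9, 0.1]   -> original ratios preserved
    reweight([900, 100], temperature=0.5)  # [0.75, 0.25] -> smaller dataset over-sampled
    reweight([900, 100], temperature=0.0)  # [0.5, 0.5]   -> all datasets equalized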

**Configuration Options:**

The ``reweight_temperature`` parameter supports flexible formats that are automatically standardized:

1. **Scalar value** (applied to all nesting levels, warning logged):

.. code-block:: yaml

    train_ds:
      use_lhotse: true
      reweight_temperature: 0.5  # Applied to all levels, warning logged
      input_cfg:
        - type: group
          input_cfg:
            - type: lhotse_shar
              shar_path: /path/to/dataset1
              weight: 900
            - type: lhotse_shar
              shar_path: /path/to/dataset2
              weight: 100
            - type: lhotse_shar
              shar_path: /path/to/dataset3
              weight: 200
            - type: nemo_tarred
              manifest_filepath: /path/to/dataset4/manifest.json
              tarred_audio_filepath: /path/to/dataset4/audio.tar
              weight: 300

2. **List matching maximum nesting depth** (one temperature per level):

.. code-block:: yaml

    train_ds:
      use_lhotse: true
      reweight_temperature: [1.0, 0.0]  # Level 1: preserve ratios, Level 2: equalize
      input_cfg:
        - type: group
          weight: 0.7
          input_cfg:
            - type: lhotse_shar
              shar_path: /path/to/dataset1
              weight: 600
            - type: lhotse_shar
              shar_path: /path/to/dataset2
              weight: 400
        - type: group
          weight: 0.3
          input_cfg:
            - type: lhotse_shar
              shar_path: /path/to/dataset3
              weight: 100

3. **List shorter than maximum nesting depth** (extended by repeating last value, warning logged):

.. code-block:: yaml

    train_ds:
      use_lhotse: true
      reweight_temperature: [1.0]  # Extended to [1.0, 1.0] for 2 levels
      input_cfg:
        - type: group
          input_cfg:
            - type: lhotse_shar
              shar_path: /path/to/dataset1

4. **List longer than maximum nesting depth** (trimmed to match, warning logged):

.. code-block:: yaml

    train_ds:
      use_lhotse: true
      reweight_temperature: [1.0, 0.5, 0.0]  # Trimmed to [1.0] for 1 level
      input_cfg:
        - type: lhotse_shar
          shar_path: /path/to/dataset1
          weight: 100
Comment on lines +203 to +226

Collaborator: We should support 1 and 2. 3 and 4 should throw an error.


**Maximum Nesting Depth Calculation:**

The maximum nesting depth is calculated as the maximum depth of ``input_cfg`` keys in the configuration. Sibling groups at the same level share the same temperature value.

.. code-block:: yaml

    # This has maximum nesting depth = 2
    input_cfg:              # Level 1
      - type: group
        input_cfg:          # Level 2
          - type: lhotse_shar
      - type: group         # Same level as above (sibling)
        input_cfg:          # Level 2 (same as above)
          - type: lhotse_shar
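
Each nesting level consumes one value from the list: the first temperature is applied when mixing the entries of the top-level ``input_cfg``, and the remaining values are handed down to nested groups, so sibling groups at the same depth see the same value. A toy sketch of that hand-off, shown only for illustration (the actual logic lives in ``parse_and_combine_datasets`` in ``cutset.py``):

.. code-block:: python

    def consume_temperature(temperatures):
        # Use the first value at the current level; pass the rest down to nested groups.
        if not temperatures:
            return 1.0, []  # no value left: fall back to the neutral temperature
        head, *rest = temperatures
        return head, rest

    level1_temp, remaining = consume_temperature([1.0, 0.0])  # level1_temp == 1.0
    level2_temp, _ = consume_temperature(remaining)           # level2_temp == 0.0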

**Example: Balancing Multiple Task Groups**

.. code-block:: yaml

    train_ds:
      use_lhotse: true
      reweight_temperature: [1.0, 0.0]  # Level 1: Preserve task ratios, Level 2: Equalize within tasks
      input_cfg:
        - type: group
          weight: 0.7
          tags:
            task: asr
          input_cfg:
            - type: nemo_tarred
              manifest_filepath: /path/to/asr1/manifest.json
              tarred_audio_filepath: /path/to/asr1/audio.tar
              weight: 600  # Large dataset
            - type: nemo_tarred
              manifest_filepath: /path/to/asr2/manifest.json
              tarred_audio_filepath: /path/to/asr2/audio.tar
              weight: 100  # Small dataset (will be upsampled with temp=0.0)
        - type: group
          weight: 0.3
          tags:
            task: ast
          input_cfg:
            - type: nemo_tarred
              manifest_filepath: /path/to/ast1/manifest.json
              tarred_audio_filepath: /path/to/ast1/audio.tar
              weight: 50
            - type: nemo_tarred
              manifest_filepath: /path/to/ast2/manifest.json
              tarred_audio_filepath: /path/to/ast2/audio.tar
              weight: 200

In this example:

- Level 1 temperature is ``1.0``: The 70/30 split between ASR and AST groups is preserved
- Level 2 temperature is ``0.0``: Within each group, all datasets are sampled equally regardless of their original weights
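
With these settings the effective sampling probabilities work out to ``asr1 = asr2 = 0.35`` (each :math:`0.7 \times 0.5`) and ``ast1 = ast2 = 0.15`` (each :math:`0.3 \times 0.5`), so the small ``asr2`` and ``ast1`` datasets are over-sampled relative to their original weights.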


Model Architecture Configuration
--------------------------------
Each configuration file should describe the model architecture being used for the experiment.
1 change: 1 addition & 0 deletions examples/tts/conf/magpietts/magpietts_lhotse.yaml
@@ -87,6 +87,7 @@ model:
    shuffle: true
    num_workers: 6
    pin_memory: true
    reweight_temperature: [1.0, 0.0] # Temperature for re-weighting datasets. 1 is a neutral value. Lower temperature over-samples smaller datasets, and vice versa.
Collaborator: Where is the documentation for what this parameter does? And can we put it in the same place as the other arguments?

Collaborator:

- we should support this value being a float (same temperature for all levels) and also define what happens when this list is shorter than the number of nested groups in input_cfg (e.g., the last temperature is re-used for each next level).

Collaborator Author:

> Where is the documentation for what this parameter does? And can we put it in the same place as the other arguments?

Good point. The only place where this new param fits in seems to be here: https://github.com/NVIDIA-NeMo/NeMo/blob/main/docs/source/audio/configs.rst#lhotse-cutset. I will add a description there.

Collaborator Author:

> we should support this value being a float (same temperature for all levels) and also define what happens when this list is shorter than the number of nested groups in input_cfg (e.g., the last temperature is re-used for each next level).

I intentionally designed this with a strict List requirement to prioritize explicit configuration over flexibility. This ensures users are fully aware of their nested data structures and define temperatures intentionally.

But I agree that we need to validate that the number of nested groups matches the list length. I'll look into how to enforce that check.

Collaborator Author: addressed all above comments.

Collaborator: I think you also need to add this to dataloader.py in LhotseDataLoadingConfig otherwise it won't get propagated to cutset.py functions. We should also have a test that uses this option through get_lhotse_dataloader_from_config to check it works.

Collaborator Author: Good point. Indeed, my test using get_lhotse_dataloader_from_config failed to pass this new param to propagate_attrs.

Fixed now.

Collaborator: This is my understanding, correct me if I am wrong: Level 1 (Language): preserves the weight, Level 2 (datasets within each language): equalizes.

Based on this understanding, don't we want to normalize the weights based on number of shar files/size of the datasets at level 2?

Collaborator Author: You're right. But this is an example demonstrating how to use this new param. In practice, you should override it based on your needs.

    input_cfg:
      - type: lhotse_shar
138 changes: 136 additions & 2 deletions nemo/collections/common/data/lhotse/cutset.py
@@ -20,7 +20,7 @@
from functools import partial
from itertools import repeat
from pathlib import Path
from typing import KeysView, Mapping, Sequence, Tuple, Union
from typing import KeysView, List, Mapping, Sequence, Tuple, Union

import numpy as np
import omegaconf
@@ -50,6 +50,39 @@
from nemo.collections.common.parts.preprocessing.manifest import get_full_path


def temperature_reweighting(weights: List[Union[float, int]], temperature: float = 1.0) -> List[float]:
"""
Apply temperature scaling to dataset weights and normalize.

Formula: normalized_weight_i = (w_i ^ temperature) / sum(w_j ^ temperature)

Args:
weights: List of dataset weights (can be hours, sample counts, or probabilities).
Values can be any positive float/int, not limited to [0, 1].
temperature: Scaling factor.
- 1.0: preserves original weight ratios
- 0.0: equalizes all weights (w^0 = 1)
- <1.0: oversamples smaller datasets
- >1.0: amplifies weight differences

Returns:
Normalized weights that sum to 1.0

Example:
>>> temperature_reweighting([197, 2159], temperature=1.0) # hours
[0.0836, 0.9164] # preserves ratio
>>> temperature_reweighting([197, 2159], temperature=0.0) # equalize
[0.5, 0.5]
"""
if len(weights) == 0:
return []
weights = np.asarray(weights)
if np.any(weights <= 0):
raise ValueError(f"All weights must be positive (> 0), got: {weights.tolist()}")
weights = weights**temperature
return (weights / weights.sum()).tolist()


def read_cutset_from_config(config: Union[DictConfig, dict]) -> Tuple[CutSet, bool]:
"""
Reads NeMo configuration and creates a CutSet either from Lhotse or NeMo manifests.
@@ -210,7 +243,51 @@ def read_dataset_config(config) -> tuple[CutSet, bool]:
"force_map_dataset": config.get("force_map_dataset", False),
"force_iterable_dataset": config.get("force_iterable_dataset", False),
"slice_length": config.get("slice_length", None),
# Temperature for re-weighting datasets. 1 is a neutral value. Lower temperature over-samples smaller datasets, and vice versa.
"reweight_temperature": config.get("reweight_temperature", None),
}

    # Standardize reweight_temperature to match the nesting depth
    if propagate_attrs["reweight_temperature"] is not None:
        expected_length = count_input_cfg_levels(config)
        reweight_temp = propagate_attrs["reweight_temperature"]

        # Case 1: Scalar value - broadcast to all levels
        if isinstance(reweight_temp, (int, float)):
            propagate_attrs["reweight_temperature"] = [float(reweight_temp)] * expected_length
            logging.warning(
                f"reweight_temperature is a scalar ({reweight_temp}), broadcasting to all {expected_length} levels. "
                f"Expanded to: {propagate_attrs['reweight_temperature']}"
            )
        else:
            # Case 2: Convert to list if needed (e.g., from ListConfig)
            reweight_temp = list(reweight_temp)
            actual_length = len(reweight_temp)

            if actual_length == expected_length:
                # Case 2.1: Exact match - no modification needed
                propagate_attrs["reweight_temperature"] = reweight_temp
            elif actual_length < expected_length:
                # Case 2.2: Too short - extend by repeating last value
                last_value = reweight_temp[-1] if reweight_temp else 1.0
                extended = reweight_temp + [last_value] * (expected_length - actual_length)
                propagate_attrs["reweight_temperature"] = extended
                logging.warning(
                    f"reweight_temperature list is shorter than nesting depth: "
                    f"got {actual_length} values for {expected_length} levels. "
                    f"Extending by repeating last value ({last_value}). "
                    f"Expanded to: {extended}"
                )
            else:
                # Case 2.3: Too long - trim to max depth
                trimmed = reweight_temp[:expected_length]
                propagate_attrs["reweight_temperature"] = trimmed
                logging.warning(
                    f"reweight_temperature list is longer than nesting depth: "
                    f"got {actual_length} values for {expected_length} levels. "
                    f"Trimming extra values. Using: {trimmed}"
                )

    cuts, is_tarred = parse_and_combine_datasets(config.input_cfg, propagate_attrs=propagate_attrs)
    return cuts, is_tarred

@@ -342,6 +419,50 @@ def attach_tags(cut, tags: dict):
    return cut


def count_input_cfg_levels(config: Union[DictConfig, dict]) -> int:
"""
Compute the maximum nesting depth of 'input_cfg' keys in the configuration.

Each 'input_cfg' represents one level of nesting that consumes one temperature
value from reweight_temperature. Since sibling groups at the same level share
the same temperature (due to propagate_attrs.copy()), we count max depth,
not total occurrences.

Args:
config: Configuration dictionary that may contain nested 'input_cfg' keys.

Returns:
Maximum nesting depth of 'input_cfg' keys.

Example:
>>> config = {
... "input_cfg": [
... {"type": "group", "input_cfg": [{"type": "nemo"}]},
... {"type": "group", "input_cfg": [{"type": "nemo"}]},
... ]
... }
>>> count_input_cfg_levels(config)
2
"""

def _max_depth(obj) -> int:
if isinstance(obj, (dict, DictConfig)):
depths = []
for key, val in obj.items():
if key == "input_cfg":
# Found input_cfg: this level counts as 1 + max depth of children
depths.append(1 + _max_depth(val))
else:
depths.append(_max_depth(val))
return max(depths, default=0)
elif isinstance(obj, (list, ListConfig)):
# For lists, find the max depth across all items (siblings)
return max((_max_depth(item) for item in obj), default=0)
return 0

return _max_depth(config)


@data_type_parser("group")
def parse_and_combine_datasets(
    config_list: Union[list[DictConfig], ListConfig], propagate_attrs: dict
@@ -351,6 +472,14 @@ def parse_and_combine_datasets(
    weights = []
    tarred_status = []

    # Extract the temperature for re-weighting datasets.
    if not propagate_attrs["reweight_temperature"]:
        temperature = 1.0
        next_temperatures = None
    else:
        temperature, *next_temperatures = propagate_attrs["reweight_temperature"]
        propagate_attrs["reweight_temperature"] = next_temperatures
Comment on lines +475 to +481
Collaborator: What happens if the length of the reweight_temperature list is not equal to the number of datasets in the config?

What happens if reweight_temperature is just a float?

Collaborator Author:

> What happens if the length of the reweight_temperature list is not equal to the number of datasets in the config?

The reweight_temperature list length only needs to match the number of nested groups.

Collaborator Author:

> What happens if reweight_temperature is just a float?

Similar question as @pzelasko's; please check my answer above.

If both of you vote for supporting a float value, meaning the same temperature is applied to all levels of nested groups, I can make changes accordingly.

Collaborator Author: Added support for a scalar value.


    if isinstance(config_list, (str, Path)):
        # Resolve local filepath /path/to/input_cfg.yaml or remote url s3://bucket/path/to/input_cfg.yaml into config contents if needed.
        config_list = OmegaConf.create(load_yaml(config_list))
@@ -394,9 +523,14 @@ def parse_and_combine_datasets(
), "Missing dataset weight. When weighting datasets, every dataset must have a specified weight."

    if len(cuts) > 1:
        if not weights:
            reweights = None
        else:
            reweights = temperature_reweighting(weights, temperature=temperature)

        cuts = mux(
            *cuts,
            weights=weights if weights else None,
            weights=reweights,
            max_open_streams=propagate_attrs["max_open_streams"],
            seed=propagate_attrs["shard_seed"],
            force_finite=propagate_attrs["force_finite"] or propagate_attrs["metadata_only"],
6 changes: 5 additions & 1 deletion nemo/collections/common/data/lhotse/dataloader.py
@@ -16,7 +16,7 @@
import warnings
from dataclasses import dataclass
from functools import partial
from typing import Any, Optional, Sequence, Union
from typing import Any, List, Optional, Sequence, Union

import numpy as np
import torch
@@ -103,6 +103,10 @@
    shard_seed: int | str = "trng"
    max_open_streams: int | None = None
    cuda_expandable_segments: bool = True
    # Temperature for re-weighting datasets. 1 is a neutral value. Lower temperature over-samples smaller datasets, and vice versa.
    # Can be a scalar (applied to all levels) or a list (one per nesting level).
    # If list length doesn't match nesting depth, it will be extended or trimmed with a warning.
    reweight_temperature: Any = None  # float | int | list[float] | None
    # e. Multi-config related options.
    # Setting multi_config=True will scan the config for keys with DictConfig values,
    # create a separate sampler for each, and fuse the samplers according to sampler_fusion.