@XuesongYang XuesongYang commented Dec 17, 2025

Summary

This PR introduces temperature-based dataset reweighting to the Lhotse dataloader, enabling dynamic adjustment of sampling probabilities across datasets without manually recalculating weights. Simply specify the desired temperature, and the weights are adjusted automatically at runtime, so there is no need to update dataset weights when adding or removing datasets.

Key Changes

1. New temperature_reweighting() function

Applies temperature scaling to dataset weights using the formula:

ŵᵢ = wᵢ^τ / Σⱼ wⱼ^τ.

  • τ = 1.0: Preserves original weight ratios (neutral)
  • τ = 0.0: Equalizes all datasets regardless of original weights
  • 0 < τ < 1.0: Over-samples smaller datasets relative to larger ones
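A minimal sketch of this scaling, assuming weights are plain floats (the actual `temperature_reweighting()` signature in the PR may differ):

```python
import numpy as np

def temperature_reweighting(weights, temperature):
    # Raise each weight to the power tau, then renormalize so the result sums to 1.
    scaled = np.asarray(weights, dtype=float) ** temperature
    return (scaled / scaled.sum()).tolist()

# tau = 1.0 preserves the original ratios; tau = 0.0 equalizes everything.
print(temperature_reweighting([3.0, 1.0], 1.0))  # [0.75, 0.25]
print(temperature_reweighting([3.0, 1.0], 0.0))  # [0.5, 0.5]
```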

2. Flexible reweight_temperature config option

Supports multiple input formats with automatic normalization:

  • Scalar: Broadcasts to all nesting levels (with warning)
  • List matching depth: Applied per-level as specified
  • List too short: Extended by repeating last value (with warning)
  • List too long: Trimmed to max depth (with warning)
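The normalization rules above can be sketched as follows (a hypothetical helper; the real logic lives in the dataloader config handling and may differ in names and warning text):

```python
import warnings

def normalize_reweight_temperature(value, max_depth):
    # Scalar: broadcast the same temperature to every nesting level.
    if isinstance(value, (int, float)):
        warnings.warn("scalar temperature broadcast to all nesting levels")
        return [float(value)] * max_depth
    temps = [float(v) for v in value]
    if len(temps) < max_depth:
        # Too short: repeat the last value until the list reaches max_depth.
        warnings.warn("temperature list extended by repeating last value")
        temps += [temps[-1]] * (max_depth - len(temps))
    elif len(temps) > max_depth:
        # Too long: trim the list to max_depth.
        warnings.warn("temperature list trimmed to max depth")
        temps = temps[:max_depth]
    return temps
```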

Example - preserve top-level ratios but equalize within sub-groups:

reweight_temperature: [1.0, 0.0]  # Level 1: preserve, Level 2: equalize
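For instance, with two top-level groups weighted 0.7/0.3 and unequal datasets inside each, `[1.0, 0.0]` keeps the 0.7/0.3 split at level 1 while splitting evenly inside each group. A worked sketch (names and numbers are illustrative, not from the PR):

```python
def reweight(weights, tau):
    # The temperature formula from the summary: w_i**tau / sum_j(w_j**tau).
    scaled = [w ** tau for w in weights]
    total = sum(scaled)
    return [s / total for s in scaled]

# Level 1 (tau = 1.0): group ratios preserved.
lang_weights = reweight([0.7, 0.3], 1.0)  # ~[0.7, 0.3] up to float rounding
# Level 2 (tau = 0.0): datasets within a group equalized.
en_weights = reweight([0.9, 0.1], 0.0)    # [0.5, 0.5]
# Effective weight of the first dataset in the first group: 0.7 * 0.5 = 0.35.
```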

3. Comprehensive test coverage

  • Unit tests for temperature_reweighting() and count_input_cfg_levels()
  • Integration tests for dataloader with various temperature configurations
  • Validation tests for scalar/list normalization behavior

Documentation

See the new "Dataset Reweighting with Temperature" section in the audio configuration docs: docs/source/audio/configs.rst

Copilot AI review requested due to automatic review settings December 17, 2025 00:30
Copilot AI left a comment

Pull request overview

This PR introduces temperature-based dataset reweighting to the Lhotse dataloader, enabling dynamic adjustment of sampling probabilities across datasets through a configurable temperature parameter. This eliminates the need for manual weight recalculation when combining datasets with different sizes.

Key changes:

  • New temperature_reweighting() function that applies temperature scaling to weights using the formula (w_i ^ temp) / sum(w_j ^ temp)
  • New reweight_temperature configuration option that supports hierarchical temperature application across nested dataset groups
  • Comprehensive test suite with 19 tests covering various input types and edge cases

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 4 comments.

Changed files:

  • nemo/collections/common/data/lhotse/cutset.py: Adds the temperature_reweighting() function and integrates it into the dataset loading pipeline with temperature propagation through nested groups
  • examples/tts/conf/magpietts/magpietts_lhotse.yaml: Adds an example configuration demonstrating hierarchical temperature usage with [1.0, 0.0]
  • tests/collections/common/test_lhotse_temperature_reweighting.py: Adds comprehensive unit and integration tests for the temperature reweighting functionality


Copilot AI left a comment
Pull request overview

Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.



Xuesong Yang and others added 6 commits January 12, 2026 21:00
…ights predefined in train_ds.dataset.input_cfg YAML configs. This feature saves the effort of flattening the dataset distribution every time a new dataset is added.

Signed-off-by: Xuesong Yang <16880-xueyang@users.noreply.gitlab-master.nvidia.com>
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
…onfig. Otherwise, it would not be passed to propagate_attrs.

Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
…t length could be shorter or longer than the max depth of the recursion group. Added tests.

Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
Copilot AI left a comment

Pull request overview

Copilot reviewed 6 out of 6 changed files in this pull request and generated 2 comments.

Comments suppressed due to low confidence (1)

nemo/collections/common/data/lhotse/dataloader.py:19

  • Import of 'List' is not used.
from typing import Any, Optional, Sequence, Union


Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Xuesong Yang <1646669+XuesongYang@users.noreply.github.com>
shuffle: true
num_workers: 6
pin_memory: true
reweight_temperature: [1.0, 0.0] # Temperature for re-weighting datasets. 1 is a neutral value. Lower temperature over-samples smaller datasets, and vice versa.
Collaborator
This is my understanding, correct me if I am wrong: Level 1 (language) preserves the weights; Level 2 (datasets within each language) equalizes them.
Based on this understanding, don't we want to normalize the weights by the number of shar files / the size of the datasets at level 2?

Collaborator Author

You're right, but this is just an example demonstrating how to use the new param. In practice, you should override it based on your needs.
