Fix operator precedence bug in VLBO loss weights#107

Open
Mr-Neutr0n wants to merge 1 commit into AILab-CVC:main from Mr-Neutr0n:fix/lvlb-weights-operator-precedence

Conversation

@Mr-Neutr0n

Summary

  • Fix operator precedence bug in lvlb_weights computation for x0 parameterization (lvdm/models/ddpm3d.py, line 159). The expression (2. * 1 - torch.Tensor(alphas_cumprod)) evaluates as (2.0 - alphas_cumprod) due to multiplication binding tighter than subtraction. The correct denominator is 2.0 * (1 - alphas_cumprod).
  • Strengthen NaN assertion from .all() to .any() so that partially-NaN weight tensors are caught immediately instead of silently passing validation.
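A minimal sketch of the precedence difference, using illustrative cumulative-alpha values rather than the model's actual schedule:

```python
import torch

# Illustrative values only, not the real noise schedule
alphas_cumprod = torch.tensor([0.9, 0.5, 0.1])

buggy = 2. * 1 - alphas_cumprod       # parses as (2.0 * 1) - acp == 2.0 - acp
fixed = 2.0 * (1 - alphas_cumprod)    # the intended denominator

print(buggy)  # tensor([1.1000, 1.5000, 1.9000])
print(fixed)  # tensor([0.2000, 1.0000, 1.8000])
```

The two denominators disagree at every timestep, and by a factor that varies with α̅_t, so the error is not a uniform rescaling of the loss.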

Details

The VLBO (variational lower bound objective) loss weighting formula for the x0 parameterization branch is:

L_vlb ∝ sqrt(α̅_t) / (2 * (1 - α̅_t))

Without the parentheses around (1 - alphas_cumprod), the denominator becomes (2 - alphas_cumprod) which distorts the relative loss weights across timesteps. This matters for any training run that uses parameterization="x0" with the VLBO loss term.

The NaN guard assert not torch.isnan(x).all() only fires when every element is NaN, so a tensor with a single NaN passes validation. Changing it to .any() makes the assertion trip on even one NaN, which is the intended safety-check behavior.
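The difference is easy to see on a toy tensor (hypothetical values):

```python
import torch

w = torch.tensor([1.0, float('nan'), 3.0])  # one NaN element

print(torch.isnan(w).all())  # tensor(False) -> old `assert not ....all()` passes silently
print(torch.isnan(w).any())  # tensor(True)  -> new `assert not ....any()` fires as intended
```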

Test plan

  • Verify lvlb_weights values match the expected VLBO formula for both eps and x0 parameterizations
  • Confirm no NaN values in lvlb_weights after initialization with standard schedules
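A sketch of these checks for the x0 branch, under an assumed linear beta schedule (the actual schedule and constants in ddpm3d.py may differ):

```python
import torch

# Stand-in linear beta schedule; made-up values, not the repo's config
betas = torch.linspace(1e-4, 2e-2, 1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

# x0-parameterization weights with the corrected denominator,
# following L_vlb ∝ sqrt(acp) / (2 * (1 - acp))
lvlb_weights = torch.sqrt(alphas_cumprod) / (2.0 * (1 - alphas_cumprod))

# Weights produced by the buggy denominator (2 - acp)
buggy = torch.sqrt(alphas_cumprod) / (2.0 * 1 - alphas_cumprod)

assert not torch.isnan(lvlb_weights).any()   # strengthened guard
assert torch.isfinite(lvlb_weights).all()
assert not torch.allclose(lvlb_weights, buggy)
```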

…uard

The x0-parameterized lvlb_weights computation had a missing pair of
parentheses that caused incorrect evaluation:

  (2. * 1 - torch.Tensor(alphas_cumprod))

Due to operator precedence, this evaluates as (2.0 - alphas_cumprod)
instead of the intended 2.0 * (1 - alphas_cumprod). This produces
wrong loss weighting across timesteps when training with x0
parameterization.

Also tighten the NaN assertion from .all() to .any() so that
partially-NaN weight tensors are caught instead of silently passing.
