Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions nflows/transforms/splines/rational_quadratic.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,10 @@ def rational_quadratic_spline(
c = -input_delta * (inputs - input_cumheights)

discriminant = b.pow(2) - 4 * a * c

float_precision_mask = (torch.abs(discriminant)/(b.pow(2) + 1e-8)) < 1e-6
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you add a comment explaining where this particular formula comes from, and define constants/parameters for the magic numbers?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have added a comment trying to explain the logic of the operation without being too verbose.
I am afraid the numbers are manly motivated by heuristics to manage numerical stability in our case.

discriminant[float_precision_mask] = 0

assert (discriminant >= 0).all()

root = (2 * c) / (-b - torch.sqrt(discriminant))
Expand Down