
Feat (core/float): better max mantissa computation #1391

Merged
Giuseppe5 merged 7 commits into Xilinx:dev from Giuseppe5:fix_float_stuff
Oct 15, 2025

Conversation

@Giuseppe5
Collaborator

@Giuseppe5 Giuseppe5 commented Oct 14, 2025

Reason for this PR

Small refactor to improve the computation of the maximum available mantissa given a certain mantissa bit width.
This avoids a data-dependent for loop, which has two main benefits (a short sketch follows the list below):

  • Allows for gradient flow of mantissa_bit_width through the function
  • Allows for fullgraph compile through the quantizers
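
A rough sketch of the idea (illustrative only, not the exact Brevitas code; the function and variable names here are made up):

import torch

# Old style: accumulate 2**-k for k = 0..mantissa_bit_width. The number of
# iterations depends on the bit-width value, which causes graph breaks under
# torch.compile and blocks gradient flow through mantissa_bit_width.
def max_mantissa_loop(mantissa_bit_width: torch.Tensor) -> torch.Tensor:
    acc = torch.zeros_like(mantissa_bit_width)
    for k in range(int(mantissa_bit_width) + 1):
        acc = acc + 2. ** (-k)
    return acc

# New style: the same finite geometric series in closed form,
# 2 * (1 - 2 ** (-m - 1)). No data-dependent control flow, so it traces
# cleanly and is differentiable w.r.t. the bit width.
def max_mantissa_closed_form(mantissa_bit_width: torch.Tensor) -> torch.Tensor:
    return 2. * (1. - 2. ** (-mantissa_bit_width - 1.))

m = torch.tensor(3.)
assert torch.isclose(max_mantissa_loop(m), max_mantissa_closed_form(m))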

The last point seems particularly relevant. From a custom training script, the following throughputs have been observed:

  • Baseline - 700 imgs/s
  • Compile before this change (with graph breaks) - 1400 imgs/s
  • Compile with this change (no graph breaks) - 2000 imgs/s

Changes Made in this PR

Testing Summary

Risk Highlight

  • This PR includes code from another work (please detail).
  • This PR contains API-breaking changes.
  • This PR depends on work in another PR (please provide links/details).
  • This PR introduces new dependencies (please detail).
  • There are coverage gaps not covered by tests.
  • Documentation updates required in subsequent PR.

Checklist

  • Code comments added to any hard-to-understand areas, if applicable.
  • Changes generate no new warnings.
  • Updated any relevant tests, if applicable.
  • No conflicts with destination dev branch.
  • I reviewed my own code changes.
  • Initial CI/CD passing.
  • 1+ reviews given, and any review issues addressed and approved.
  • Post-review full CI/CD passing.

@Giuseppe5 Giuseppe5 self-assigned this Oct 14, 2025
@nickfraser
Collaborator

Not sure why your tests are failing - I even sanity-checked that your fix makes sense on my end:

import torch

# Loop/arange version: sums 2**-k for k = 0..mm (a finite geometric series).
def cmm1(mm):
    return torch.sum((2. ** torch.arange(0, -1. * mm - 1., -1.)))

# Closed-form version of the same series.
def cmm2(mm):
    return 2 * (1 - 2 ** (-mm - 1))

# Both agree for every mantissa bit width from 0 to 19.
for i in range(0, 20):
    mm = float(i)
    assert cmm1(mm) == cmm2(mm)

^Works with no error.
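
For reference, the two expressions agree because the arange version sums a finite geometric series: sum_{k=0}^{m} 2^(-k) = (1 - 2^(-(m+1))) / (1 - 1/2) = 2 * (1 - 2^(-m-1)).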

Collaborator

@nickfraser nickfraser left a comment

A few comments. Not sure why this change is causing failing tests - it seems pretty benign to me!

Please fix the tests though 🙏

  def __init__(self, value):
      super().__init__()
-     self.value = torch.tensor(value)
+     self.value = torch.tensor(float(value))
Collaborator

Not obvious why this change is necessary. Worth a comment?

Collaborator Author

And I had to add another one.

We need to make sure the bit width is a float, so that the max mantissa computation stays in floating point; otherwise it gets rounded to an int.
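
A minimal illustration of the dtype difference (general PyTorch behaviour, not the Brevitas code itself):

import torch

# torch.tensor on a Python int gives an integer tensor, so downstream
# arithmetic can end up in integer dtype and drop the fractional part of the
# max mantissa; casting to float first keeps the computation in floating point.
print(torch.tensor(3).dtype)         # torch.int64
print(torch.tensor(float(3)).dtype)  # torch.float32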

Collaborator

Got it, but I meant a code comment ;)

      # scale inp manually
      scaled_inp = inp / scale
-     max_mantissa = compute_max_mantissa(torch.tensor(mantissa_bit_width))
+     max_mantissa = compute_max_mantissa(torch.tensor(float(mantissa_bit_width)))
Collaborator

Same here - not obvious why this change is necessary. Worth a comment?

@Giuseppe5 Giuseppe5 requested a review from nickfraser October 15, 2025 09:17
@Giuseppe5 Giuseppe5 merged commit b36f0ac into Xilinx:dev Oct 15, 2025
657 of 662 checks passed
@Giuseppe5 Giuseppe5 deleted the fix_float_stuff branch October 15, 2025 12:13