Feat (core/float): better max mantissa computation #1391

Giuseppe5 merged 7 commits into Xilinx:dev
Conversation
Not sure why your tests are failing - I even sanity checked that your fix makes sense on my end:

```python
import torch

def cmm1(mm):
    return torch.sum((2. ** torch.arange(0, -1. * mm - 1., -1.)))

def cmm2(mm):
    return 2 * (1 - 2 ** (-mm - 1))

for i in range(0, 20):
    mm = float(i)
    assert cmm1(mm) == cmm2(mm)
```

Works with no error.
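For reference, the identity being checked here is just the closed form of a finite geometric series: with mantissa bit width $m$,

$$\sum_{k=0}^{m} 2^{-k} = 2 - 2^{-m} = 2\left(1 - 2^{-(m+1)}\right),$$

which is why `cmm1` (the explicit sum) and `cmm2` (the closed form) agree over the range tested.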
nickfraser left a comment:
A few comments. Not sure why this change is causing failing tests - it seems pretty benign to me!
Please fix the tests though 🙏
```diff
 def __init__(self, value):
     super().__init__()
-    self.value = torch.tensor(value)
+    self.value = torch.tensor(float(value))
```
Not obvious why this change is necessary. Worth a comment?
And I had to add another one.
We need to make sure bitwidth is a float, so that the max mantissa computation is a float, otherwise it gets rounded to an int.
Got it, but I meant a code comment ;)
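To illustrate the point discussed in this thread, here is a minimal sketch of the dtype pitfall (the `mm_int`/`mm_float` names are illustrative, not from the codebase): with an integer tensor, PyTorch rejects the negative power outright, while a float tensor keeps the whole computation in floating point.

```python
import torch

mm_int = torch.tensor(3)     # int64 tensor, as produced by torch.tensor(mantissa_bit_width)
mm_float = torch.tensor(3.)  # float32 tensor, as produced by torch.tensor(float(mantissa_bit_width))

# Integer tensor: PyTorch refuses negative integer powers.
try:
    2 ** (-mm_int - 1)
except RuntimeError as e:
    print(e)  # "Integers to negative integer powers are not allowed."

# Float tensor: the max mantissa computation stays in floating point.
print(2 * (1 - 2 ** (-mm_float - 1)))  # tensor(1.8750)
```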
```diff
 # scale inp manually
 scaled_inp = inp / scale
-max_mantissa = compute_max_mantissa(torch.tensor(mantissa_bit_width))
+max_mantissa = compute_max_mantissa(torch.tensor(float(mantissa_bit_width)))
```
Same here - not obvious why this change is necessary. Worth a comment?
Reason for this PR
A small refactor to improve the computation of the maximum available mantissa for a given mantissa bit width.
This avoids a data-dependent for loop, which has two main benefits:

- no need to trace `mantissa_bit_width` through the function

The last point seems particularly relevant. From a custom training script, the following times have been observed:
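As a sketch of the refactor being described, assuming the original implementation accumulated the mantissa terms in a loop (only `compute_max_mantissa` is a name confirmed by the diff above; the loop variant is hypothetical). The closed form matches `cmm2` from the review comment.

```python
import torch

# Before (hypothetical): data-dependent loop over the mantissa bit width,
# which must be re-traced whenever mantissa_bit_width changes.
def compute_max_mantissa_loop(mantissa_bit_width: torch.Tensor) -> torch.Tensor:
    max_mantissa = torch.zeros(())
    for k in range(int(mantissa_bit_width) + 1):
        max_mantissa = max_mantissa + 2. ** (-k)
    return max_mantissa

# After: closed-form geometric series, a single traceable tensor expression.
def compute_max_mantissa(mantissa_bit_width: torch.Tensor) -> torch.Tensor:
    return 2. * (1. - 2. ** (-mantissa_bit_width - 1.))

mbw = torch.tensor(3.)
assert torch.isclose(compute_max_mantissa_loop(mbw), compute_max_mantissa(mbw))
```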
Changes Made in this PR
Testing Summary
Risk Highlight
Checklist
- PR is based against the dev branch.