Commit 912f3fd

Feat (core): module for runtime computation of exp bias (#1418)
1 parent 60e5558 commit 912f3fd

File tree

1 file changed: +23 -1 lines changed


src/brevitas/core/bit_width/float.py

Lines changed: 23 additions & 1 deletion
@@ -81,10 +81,32 @@ class StaticExponentBias(torch.nn.Module):
     a checkpoint but will be properly handled during device transfers and dtype conversions.
     """

-    def __init__(self, exponent_bias, device=None, dtype=None):
+    def __init__(
+            self, exponent_bias: float, device: torch.device = None, dtype: torch.dtype = None):
         super().__init__()
         self.exponent_bias = StatelessBuffer(
             torch.tensor(float(exponent_bias), device=device, dtype=dtype))

     def forward(self):
         return self.exponent_bias()
+
+
+class ComputeExponentBias(torch.nn.Module):
+    """
+    Module that returns a runtime-computed exponent bias value.
+
+    Args:
+        exponent_bit_width_impl: Module that returns the exponent bit width
+
+    Examples:
+        >>> exp_bias = ComputeExponentBias(BitWidthConst(4))
+        >>> exp_bias()
+        tensor(7.)
+    """
+
+    def __init__(self, exponent_bit_width_impl: torch.nn.Module):
+        super().__init__()
+        self.exponent_bit_width_impl = exponent_bit_width_impl
+
+    def forward(self):
+        return 2 ** (self.exponent_bit_width_impl() - 1) - 1
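
Below the diff, a minimal, self-contained sketch of how the new module is meant to be composed: any module whose forward() returns the exponent bit width can drive the bias computation. ConstExponentBitWidth is a hypothetical stand-in for that dependency (in Brevitas itself, BitWidthConst plays this role), and ComputeExponentBias is reproduced from the diff above.

import torch


class ConstExponentBitWidth(torch.nn.Module):
    """Hypothetical stand-in for a bit-width module (e.g. Brevitas' BitWidthConst):
    forward() returns the exponent bit width as a tensor."""

    def __init__(self, bit_width: float):
        super().__init__()
        self.bit_width = torch.tensor(float(bit_width))

    def forward(self):
        return self.bit_width


class ComputeExponentBias(torch.nn.Module):
    """Reproduced from the diff above."""

    def __init__(self, exponent_bit_width_impl: torch.nn.Module):
        super().__init__()
        self.exponent_bit_width_impl = exponent_bit_width_impl

    def forward(self):
        # IEEE-754-style bias for an e-bit exponent field: 2 ** (e - 1) - 1.
        return 2 ** (self.exponent_bit_width_impl() - 1) - 1


# FP8 E4M3 has a 4-bit exponent, so the bias is 2 ** 3 - 1 = 7;
# a 5-bit exponent (as in FP16's E5M10 layout) gives 2 ** 4 - 1 = 15.
print(ComputeExponentBias(ConstExponentBitWidth(4))())  # tensor(7.)
print(ComputeExponentBias(ConstExponentBitWidth(5))())  # tensor(15.)

Unlike StaticExponentBias, which stores a fixed value in a StatelessBuffer, the bias here is recomputed on every forward call, so it automatically tracks a learned or otherwise runtime-varying exponent bit width.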

0 commit comments