File tree Expand file tree Collapse file tree 1 file changed +23
-1
lines changed
src/brevitas/core/bit_width Expand file tree Collapse file tree 1 file changed +23
-1
lines changed Original file line number Diff line number Diff line change @@ -81,10 +81,32 @@ class StaticExponentBias(torch.nn.Module):
8181 a checkpoint but will be properly handled during device transfers and dtype conversions.
8282 """
8383
84- def __init__ (self , exponent_bias , device = None , dtype = None ):
84+ def __init__ (
85+ self , exponent_bias : float , device : torch .device = None , dtype : torch .dtype = None ):
8586 super ().__init__ ()
8687 self .exponent_bias = StatelessBuffer (
8788 torch .tensor (float (exponent_bias ), device = device , dtype = dtype ))
8889
8990 def forward (self ):
9091 return self .exponent_bias ()
92+
93+
94+ class ComputeExponentBias (torch .nn .Module ):
95+ """
96+ Module that returns a runtime-computed exponent bias value.
97+
98+ Args:
99+ exponent_bit_width_impl: Module that returns the exponent bit width
100+
101+ Examples:
102+ >>> exp_bias = ComputeExponentBias(4.)
103+ >>> exp_bias()
104+ tensor(7.)
105+ """
106+
107+ def __init__ (self , exponent_bit_width_impl : torch .nn .Module ):
108+ super ().__init__ ()
109+ self .exponent_bit_width_impl = exponent_bit_width_impl
110+
111+ def forward (self ):
112+ return 2 ** (self .exponent_bit_width_impl () - 1 ) - 1
You can’t perform that action at this time.
0 commit comments