File tree Expand file tree Collapse file tree 1 file changed +3
-2
lines changed
src/brevitas/core/bit_width Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Original file line number Diff line number Diff line change @@ -81,7 +81,8 @@ class StaticExponentBias(torch.nn.Module):
8181 a checkpoint but will be properly handled during device transfers and dtype conversions.
8282 """
8383
84- def __init__ (self , exponent_bias , device = None , dtype = None ):
84+ def __init__ (
85+ self , exponent_bias : float , device : torch .device = None , dtype : torch .dtype = None ):
8586 super ().__init__ ()
8687 self .exponent_bias = StatelessBuffer (
8788 torch .tensor (float (exponent_bias ), device = device , dtype = dtype ))
@@ -103,7 +104,7 @@ class ComputeExponentBias(torch.nn.Module):
103104 tensor(7.)
104105 """
105106
106- def __init__ (self , exponent_bit_width_impl ):
107+ def __init__ (self , exponent_bit_width_impl : torch . nn . Module ):
107108 super ().__init__ ()
108109 self .exponent_bit_width_impl = exponent_bit_width_impl
109110
You can’t perform that action at this time.
0 commit comments