Skip to content

Commit b27ea37

Browse files
committed
typing
1 parent 6b44807 commit b27ea37

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

src/brevitas/core/bit_width/float.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,8 @@ class StaticExponentBias(torch.nn.Module):
8181
a checkpoint but will be properly handled during device transfers and dtype conversions.
8282
"""
8383

84-
def __init__(self, exponent_bias, device=None, dtype=None):
84+
def __init__(
85+
self, exponent_bias: float, device: torch.device = None, dtype: torch.dtype = None):
8586
super().__init__()
8687
self.exponent_bias = StatelessBuffer(
8788
torch.tensor(float(exponent_bias), device=device, dtype=dtype))
@@ -103,7 +104,7 @@ class ComputeExponentBias(torch.nn.Module):
103104
tensor(7.)
104105
"""
105106

106-
def __init__(self, exponent_bit_width_impl):
107+
def __init__(self, exponent_bit_width_impl: torch.nn.Module):
107108
super().__init__()
108109
self.exponent_bit_width_impl = exponent_bit_width_impl
109110

0 commit comments

Comments
 (0)