44from typing import Optional
55from typing import Union
66
7+ import torch
78from torch import Tensor
89
910import brevitas
1415class IntScaling (brevitas .jit .ScriptModule ):
1516 __constants__ = ['signed' , 'narrow_range' ]
1617
17- def __init__ (self , narrow_range : bool , signed : Optional [bool ] = None ):
18+ def __init__ (
19+ self ,
20+ narrow_range : bool ,
21+ signed : Optional [bool ] = None ,
22+ dtype : Optional [torch .dtype ] = None ):
1823 super (IntScaling , self ).__init__ ()
1924 self .signed = signed
2025 self .narrow_range = narrow_range
26+ self .dtype = dtype
2127
2228 @brevitas .jit .script_method
2329 def forward (self , bit_width : Tensor , signed : Optional [Union [bool , Tensor ]] = None ) -> Tensor :
@@ -28,9 +34,9 @@ def forward(self, bit_width: Tensor, signed: Optional[Union[bool, Tensor]] = Non
2834 # Workaround: required for compatibility with the JIT for PT=2.2.2
2935 is_signed = bool (is_signed .item ()) if isinstance (is_signed , Tensor ) else is_signed
3036 if is_signed :
31- return - min_int (is_signed , self .narrow_range , bit_width )
37+ return - min_int (is_signed , self .narrow_range , bit_width ). to ( dtype = self . dtype )
3238 else :
33- return max_int (is_signed , self .narrow_range , bit_width )
39+ return max_int (is_signed , self .narrow_range , bit_width ). to ( dtype = self . dtype )
3440
3541
3642class PowerOfTwoIntScaling (brevitas .jit .ScriptModule ):
0 commit comments