Skip to content

Commit 54db11f

Browse files
authored
Fix (core/scaling): fix dtype for int threshold (#1404)
1 parent 004479e commit 54db11f

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

src/brevitas/core/scaling/int_scaling.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import Optional
55
from typing import Union
66

7+
import torch
78
from torch import Tensor
89

910
import brevitas
@@ -14,10 +15,15 @@
1415
class IntScaling(brevitas.jit.ScriptModule):
1516
__constants__ = ['signed', 'narrow_range']
1617

17-
def __init__(self, narrow_range: bool, signed: Optional[bool] = None):
18+
def __init__(
19+
self,
20+
narrow_range: bool,
21+
signed: Optional[bool] = None,
22+
dtype: Optional[torch.dtype] = None):
1823
super(IntScaling, self).__init__()
1924
self.signed = signed
2025
self.narrow_range = narrow_range
26+
self.dtype = dtype
2127

2228
@brevitas.jit.script_method
2329
def forward(self, bit_width: Tensor, signed: Optional[Union[bool, Tensor]] = None) -> Tensor:
@@ -28,9 +34,9 @@ def forward(self, bit_width: Tensor, signed: Optional[Union[bool, Tensor]] = Non
2834
# Workaround: required for compatibility with the JIT for PT=2.2.2
2935
is_signed = bool(is_signed.item()) if isinstance(is_signed, Tensor) else is_signed
3036
if is_signed:
31-
return -min_int(is_signed, self.narrow_range, bit_width)
37+
return -min_int(is_signed, self.narrow_range, bit_width).to(dtype=self.dtype)
3238
else:
33-
return max_int(is_signed, self.narrow_range, bit_width)
39+
return max_int(is_signed, self.narrow_range, bit_width).to(dtype=self.dtype)
3440

3541

3642
class PowerOfTwoIntScaling(brevitas.jit.ScriptModule):

0 commit comments

Comments
 (0)