
Commit aebedb9

fix: Replaced FloatTensor,IntTensor w/ Tensor

Signed-off-by: Brandon Groth <[email protected]>
1 parent 0173209, commit aebedb9

29 files changed: +685 / -685 lines
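
For context (not part of the commit): torch.FloatTensor and torch.IntTensor are legacy tensor types pinned to one dtype on the CPU, while torch.Tensor covers every dtype and device, so it is the appropriate annotation for arguments that may arrive in half precision or on GPU. A minimal sketch of the distinction, assuming a standard PyTorch install:

import torch

# torch.FloatTensor is a legacy alias meaning "float32 on CPU";
# torch.IntTensor likewise means "int32 on CPU".
x = torch.FloatTensor([1.0, 2.0])
print(x.dtype, x.device)                 # torch.float32 cpu

# torch.Tensor is the general type: any dtype, any device.
y = torch.tensor([1.0, 2.0], dtype=torch.float16)
print(isinstance(y, torch.Tensor))       # True
print(isinstance(y, torch.FloatTensor))  # False: wrong dtype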

fms_mo/quant/quantizers.py (1 addition & 1 deletion)

@@ -2544,7 +2544,7 @@ def asymmetric_linear_quantization_params(
     return scale, zero_point


-def clamp(input_tensor: torch.FloatTensor, clamp_min, clamp_max, inplace=False):
+def clamp(input_tensor: torch.Tensor, clamp_min, clamp_max, inplace=False):
     """
     Returns:
         Clamped Torch Tensor.

fms_mo/quant_refactor/base_quant.py (5 additions & 5 deletions)

@@ -175,7 +175,7 @@ class Quantizer(torch.nn.Module):

     def __init__(
         self,
-        num_bits: torch.IntTensor,
+        num_bits: torch.Tensor,
         dequantize: bool = True,
         qscheme: Qscheme = torch.per_tensor_symmetric,
         use_PT_native_Qfunc: bool = False,
@@ -189,7 +189,7 @@ def __init__(
         Init Quantizer Class

         Args:
-            num_bits (torch.IntTensor): Number of bit for quantization.
+            num_bits (torch.Tensor): Number of bit for quantization.
             dequantize (bool, optional): Return dequantized or int tensor. Defaults to True.
             qscheme (Qscheme, optional): Quantization scheme.
                 Defaults to Qscheme(unit="perT", symmetric=True).
@@ -257,7 +257,7 @@ def set_quantizer(self):
            f"Quantizer selection is not implemented for quantizer {self}"
         )

-    def forward(self, input_tensor: torch.FloatTensor):
+    def forward(self, input_tensor: torch.Tensor):
         """
         General forward() function for quantizer classes.

@@ -266,13 +266,13 @@ def forward(self, input_tensor: torch.FloatTensor):
         To STE functions without calling calc_qparams()

         Args:
-            input_tensor (torch.FloatTensor): Tensor to be quantized.
+            input_tensor (torch.Tensor): Tensor to be quantized.

         Raises:
             ValueError: Single-sided qscheme has tensor min < 0.0

         Returns:
-            torch.FloatTensor: Dequantized or Quantized output tensor.
+            torch.Tensor: Dequantized or Quantized output tensor.
         """
         if self.qscheme.single_sided and input_tensor.min() < 0.0:
             raise ValueError(
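
A side note (not from the commit): widening num_bits from torch.IntTensor to torch.Tensor matters because torch.IntTensor only matches int32 CPU tensors, while the tensors most callers build default to int64. A hypothetical illustration:

import torch

# Only the first of these is a torch.IntTensor (int32, CPU);
# all of them satisfy the relaxed num_bits: torch.Tensor annotation.
bits_int32 = torch.IntTensor([8])
bits_int64 = torch.tensor(8)  # default integer dtype is int64
print(isinstance(bits_int64, torch.IntTensor))  # False
print(isinstance(bits_int64, torch.Tensor))     # True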

fms_mo/quant_refactor/get_quantizer_new.py (4 additions & 4 deletions)

@@ -41,8 +41,8 @@
 def get_activation_quantizer_new(
     qa_mode: str = "PACT",
     nbits: int = 32,
-    clip_val: torch.FloatTensor = None,
-    clip_valn: torch.FloatTensor = None,
+    clip_val: torch.Tensor = None,
+    clip_valn: torch.Tensor = None,
     non_neg: bool = False,
     align_zero: bool = True,  # pylint: disable=unused-argument
     extend_act_range: bool = False,
@@ -215,8 +215,8 @@ def get_activation_quantizer_new(
 def get_weight_quantizer_new(
     qw_mode: str = "SAWB+",
     nbits: int = 32,
-    clip_val: torch.FloatTensor = None,
-    clip_valn: torch.FloatTensor = None,
+    clip_val: torch.Tensor = None,
+    clip_valn: torch.Tensor = None,
     align_zero: bool = True,
     w_shape: torch.Size = None,
     recompute: bool = False,  # pylint: disable=unused-argument
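
A hypothetical call sketch (not part of the diff; only parameters visible in the hunks are used, the module path is inferred from the file location, and the remaining defaults are assumed to suffice): clip values can now be passed as plain tensors of whatever dtype the model uses.

import torch
from fms_mo.quant_refactor.get_quantizer_new import get_activation_quantizer_new

# PACT activation quantizer with an 8-bit range and tensor clip values;
# after this commit any torch.Tensor is accepted, not only float32 on CPU.
act_q = get_activation_quantizer_new(
    qa_mode="PACT",
    nbits=8,
    clip_val=torch.tensor(6.0, dtype=torch.float16),
    clip_valn=torch.tensor(0.0, dtype=torch.float16),
)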
