Skip to content

Commit c20c03c

Browse files
Fix corner cases global shift gen
In some corner cases, the global shift factor was generated as a number < 0 (down to -inf...). This makes no sense, so now the global shift must be 0 at a minimum.
1 parent 13dd71f commit c20c03c

File tree

1 file changed

+5
-1
lines changed

1 file changed

+5
-1
lines changed

test/NnxTestClasses.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,11 @@ def _calculate_global_shift(
213213
"""Calculate global shift so that the output values are in the range of out_type"""
214214
s = tensor.type(torch.float64).std()
215215
target_s = 2 ** (out_type._bits - 1)
216-
return torch.ceil(torch.log2(s / target_s)).type(torch.int32)
216+
shift = torch.ceil(torch.log2(s / target_s)).type(torch.int32)
217+
if shift < 1:
218+
return torch.zeros((1,)).type(torch.int32)
219+
else:
220+
return shift
217221

218222
@staticmethod
219223
def _random_data(_type: IntegerType, shape: Tuple, extremes: Tuple = None):

0 commit comments

Comments
 (0)