Skip to content

Commit c9d56b6

Browse files
committed
format negative scientific notation correctly
1 parent 193075d commit c9d56b6

File tree

2 files changed

+54
-7
lines changed

2 files changed

+54
-7
lines changed

bayes_opt/logger.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,16 +75,38 @@ def _format_number(self, x: float) -> str:
7575
-------
7676
A stringified, formatted version of `x`.
7777
"""
78-
if isinstance(x, int):
79-
s = f"{x:<{self._default_cell_size}}"
78+
if abs(x) > 1e7 - 1:
79+
s = f"{x:.5e}"
8080
else:
81-
s = f"{x:<{self._default_cell_size}.{self._default_precision}}"
81+
s = str(x)
8282

8383
if len(s) > self._default_cell_size:
84+
# Convert to str representation of scientific notation
85+
result = ""
86+
width = self._default_cell_size
87+
# Keep negative sign, exponent, and as many decimal places as possible
88+
if "-" in s:
89+
result += "-"
90+
width -= 1
91+
s = s[1:]
92+
if "e" in s:
93+
e_pos = s.find("e")
94+
end = s[e_pos:]
95+
width -= len(end)
8496
if "." in s:
85-
return s[: self._default_cell_size]
86-
return s[: self._default_cell_size - 3] + "..."
87-
return s
97+
dot_pos = s.find(".") + 1
98+
result += s[:dot_pos]
99+
width -= dot_pos
100+
if width > 0:
101+
result += s[dot_pos : dot_pos + width]
102+
else:
103+
result += s[:width]
104+
if "e" in s:
105+
result += end
106+
result = result.ljust(self._default_cell_size)
107+
else:
108+
result = s.ljust(self._default_cell_size)
109+
return result
88110

89111
def _format_bool(self, x: bool) -> str:
90112
"""Format a boolean.

tests/test_logger.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,37 @@ def test_format_number():
6767
long_int = 12345678901234
6868
formatted = logger._format_number(long_int)
6969
assert len(formatted) == logger._default_cell_size
70-
assert "..." in formatted
70+
assert formatted == "1.234e+13"
7171

7272
# Test long float truncation
7373
long_float = 1234.5678901234
7474
formatted = logger._format_number(long_float)
7575
assert len(formatted) == logger._default_cell_size
76+
assert formatted == "1234.5678"
77+
78+
# Test negative long float truncation
79+
long_float = -1234.5678901234
80+
formatted = logger._format_number(long_float)
81+
assert len(formatted) == logger._default_cell_size
82+
assert formatted == "-1234.567"
83+
84+
# Test scientific notation truncation
85+
sci_float = 12345678901234.5678901234
86+
formatted = logger._format_number(sci_float)
87+
assert len(formatted) == logger._default_cell_size
88+
assert formatted == "1.234e+13"
89+
90+
# Test negative scientific notation truncation
91+
sci_float = -12345678901234.5678901234
92+
formatted = logger._format_number(sci_float)
93+
assert len(formatted) == logger._default_cell_size
94+
assert formatted == "-1.23e+13"
95+
96+
# Test long scientific notation truncation
97+
sci_float = -12345678901234.534e132
98+
formatted = logger._format_number(sci_float)
99+
assert len(formatted) == logger._default_cell_size
100+
assert formatted == "-1.2e+145"
76101

77102

78103
def test_format_bool():

0 commit comments

Comments
 (0)