Skip to content

Commit c7ed6ab

Browse files
James Wilson authored and facebook-github-bot committed
Enforce use of float64 in NdarrayOptimizationClosure (#1508)
Summary: Pull Request resolved: #1508 Reviewed By: esantorella Differential Revision: D41355824 fbshipit-source-id: 284fb4e9f0f8571a1b2905bbf5e3a5f4b4900298
1 parent 17b1bb7 commit c7ed6ab

File tree

5 files changed

+23
-15
lines changed

5 files changed

+23
-15
lines changed

botorch/optim/closures/core.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -110,11 +110,14 @@ def __init__(
110110
"""
111111
if get_state is None:
112112
# Note: Numpy supports copying data between ndarrays with different dtypes.
113-
# Hence, our default behavior need not coerce the ndarray represenations of
114-
# tensors in `parameters` to float64 when copying over data.
113+
# Hence, our default behavior need not coerce the ndarray representations
114+
# of tensors in `parameters` to float64 when copying over data.
115115
_as_array = as_ndarray if as_array is None else as_array
116116
get_state = partial(
117-
get_tensors_as_ndarray_1d, parameters, as_array=_as_array
117+
get_tensors_as_ndarray_1d,
118+
tensors=parameters,
119+
dtype=np_float64,
120+
as_array=_as_array,
118121
)
119122

120123
if as_array is None: # per the note, do this after resolving `get_state`
@@ -154,7 +157,7 @@ def __call__(
154157
grads[index : index + size] = self.as_array(grad.view(-1))
155158
index += size
156159
except RuntimeError as e:
157-
value, grads = _handle_numerical_errors(error=e, x=self.state)
160+
value, grads = _handle_numerical_errors(e, x=self.state, dtype=np_float64)
158161

159162
return value, grads
160163

@@ -174,9 +177,9 @@ def _get_gradient_ndarray(self, fill_value: Optional[float] = None) -> ndarray:
174177

175178
size = sum(param.numel() for param in self.parameters.values())
176179
array = (
177-
np_zeros(size)
180+
np_zeros(size, dtype=np_float64)
178181
if fill_value is None or fill_value == 0.0
179-
else np_full(size, fill_value)
182+
else np_full(size, fill_value, dtype=np_float64)
180183
)
181184
if self.persistent:
182185
self._gradient_ndarray = array

botorch/optim/core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
from botorch.optim.closures import NdarrayOptimizationClosure
2020
from botorch.optim.utils import get_bounds_as_ndarray
21-
from numpy import asarray, ndarray
21+
from numpy import asarray, float64 as np_float64, ndarray
2222
from scipy.optimize import minimize
2323
from torch import Tensor
2424
from torch.optim.adam import Adam
@@ -105,7 +105,7 @@ def wrapped_callback(x: ndarray):
105105

106106
raw = minimize(
107107
wrapped_closure,
108-
wrapped_closure.state if x0 is None else x0,
108+
wrapped_closure.state if x0 is None else x0.astype(np_float64, copy=False),
109109
jac=True,
110110
bounds=bounds_np,
111111
method=method,

botorch/optim/utils/common.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def _filter_kwargs(function: Callable, **kwargs: Any) -> Any:
3333

3434

3535
def _handle_numerical_errors(
36-
error: RuntimeError, x: np.ndarray
36+
error: RuntimeError, x: np.ndarray, dtype: Optional[np.dtype] = None
3737
) -> Tuple[np.ndarray, np.ndarray]:
3838
if isinstance(error, NotPSDError):
3939
raise error
@@ -43,7 +43,8 @@ def _handle_numerical_errors(
4343
or "singular" in error_message # old pytorch message
4444
or "input is not positive-definite" in error_message # since pytorch #63864
4545
):
46-
return np.full((), "nan", dtype=x.dtype), np.full_like(x, "nan")
46+
_dtype = x.dtype if dtype is None else dtype
47+
return np.full((), "nan", dtype=_dtype), np.full_like(x, "nan", dtype=_dtype)
4748
raise error # pragma: nocover
4849

4950

botorch/optim/utils/numpy_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def as_ndarray(
6666

6767
# Convert to ndarray and maybe cast to `dtype`
6868
out = out.numpy()
69-
return out if (dtype is None or dtype == out.dtype) else out.astype(dtype)
69+
return out.astype(dtype, copy=False)
7070

7171

7272
def get_tensors_as_ndarray_1d(

test/optim/utils/test_common.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,23 +17,27 @@
1717

1818
class TestUtilsCommon(BotorchTestCase):
1919
def test_handle_numerical_errors(self):
20-
x = np.zeros(1)
20+
x = np.zeros(1, dtype=np.float64)
2121

2222
with self.assertRaisesRegex(NotPSDError, "foo"):
23-
_handle_numerical_errors(error=NotPSDError("foo"), x=x)
23+
_handle_numerical_errors(NotPSDError("foo"), x=x)
2424

2525
for error in (
2626
NanError(),
2727
RuntimeError("singular"),
2828
RuntimeError("input is not positive-definite"),
2929
):
30-
fake_loss, fake_grad = _handle_numerical_errors(error=error, x=x)
30+
fake_loss, fake_grad = _handle_numerical_errors(error, x=x)
3131
self.assertTrue(np.isnan(fake_loss))
3232
self.assertEqual(fake_grad.shape, x.shape)
3333
self.assertTrue(np.isnan(fake_grad).all())
3434

35+
fake_loss, fake_grad = _handle_numerical_errors(error, x=x, dtype=np.float32)
36+
self.assertEqual(np.float32, fake_loss.dtype)
37+
self.assertEqual(np.float32, fake_grad.dtype)
38+
3539
with self.assertRaisesRegex(RuntimeError, "foo"):
36-
_handle_numerical_errors(error=RuntimeError("foo"), x=x)
40+
_handle_numerical_errors(RuntimeError("foo"), x=x)
3741

3842
def test_warning_handler_template(self):
3943
with catch_warnings(record=True) as ws:

0 commit comments

Comments
 (0)