Commit 9da6f22

sdaulton authored and facebook-github-bot committed
add utilities for straight-through gradient estimators for discretization functions (#1515)
Summary:
Pull Request resolved: #1515

see title

Reviewed By: Balandat

Differential Revision: D41475380

fbshipit-source-id: d5ba14b4f4e9c9fe51be73eec45ed03f625711f1
1 parent 92f0d1d commit 9da6f22
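
For context: rounding and argmax are piecewise constant, so their true gradients are zero almost everywhere, which stalls gradient-based acquisition optimization. A straight-through estimator keeps the exact discretization in the forward pass but passes incoming gradients through unchanged in the backward pass. A minimal sketch of the idea in plain PyTorch (the detach-based formulation below is only an illustration of the trick, not the implementation this commit adds):

import torch

X = torch.tensor([0.3, 1.6, 2.5], requires_grad=True)
# Forward value equals X.round(); the (X.round() - X).detach() term is
# a constant to autograd, so the expression has derivative 1 w.r.t. X.
rounded = X + (X.round() - X).detach()
rounded.sum().backward()
print(rounded)  # tensor([0., 2., 2.], grad_fn=...)
print(X.grad)   # tensor([1., 1., 1.])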

File tree: 3 files changed, +146 -2 lines changed


botorch/test_functions/multi_objective.py

Lines changed: 2 additions & 1 deletion

@@ -11,7 +11,8 @@

 .. [Daulton2022]
     S. Daulton, S. Cakmak, M. Balandat, M. A. Osborne, E. Zhou, and E. Bakshy.
-    Robust Multi-Objective Bayesian Optimization Under Input Noise. 2022.
+    Robust Multi-Objective Bayesian Optimization Under Input Noise.
+    Proceedings of the 39th International Conference on Machine Learning, 2022.

 .. [Deb2005dtlz]
     K. Deb, L. Thiele, M. Laumanns, E. Zitzler, A. Abraham, L. Jain, and

botorch/utils/rounding.py

Lines changed: 79 additions & 0 deletions

@@ -4,10 +4,24 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

+r"""
+Discretization (rounding) functions for acquisition optimization.
+
+References
+
+.. [Daulton2022bopr]
+    S. Daulton, X. Wan, D. Eriksson, M. Balandat, M. A. Osborne, E. Bakshy.
+    Bayesian Optimization over Discrete and Mixed Spaces via Probabilistic
+    Reparameterization. Advances in Neural Information Processing Systems
+    35, 2022.
+"""
+
 from __future__ import annotations

 import torch
 from torch import Tensor
+from torch.autograd import Function
+from torch.nn.functional import one_hot


 def approximate_round(X: Tensor, tau: float = 1e-3) -> Tensor:
@@ -27,3 +41,68 @@ def approximate_round(X: Tensor, tau: float = 1e-3) -> Tensor:
     scaled_remainder = (X - offset - 0.5) / tau
     rounding_component = (torch.tanh(scaled_remainder) + 1) / 2
     return offset + rounding_component
+
+
+class IdentitySTEFunction(Function):
+    """Base class for functions using straight-through gradient estimators.
+
+    This class approximates the gradient with the identity function.
+    """
+
+    @staticmethod
+    def backward(ctx, grad_output: Tensor) -> Tensor:
+        r"""Use a straight-through estimator for the gradient.
+
+        This uses the identity function.
+
+        Args:
+            grad_output: A tensor of gradients.
+
+        Returns:
+            The provided tensor.
+        """
+        return grad_output
+
+
+class RoundSTE(IdentitySTEFunction):
+    r"""Round the input tensor and use a straight-through gradient estimator.
+
+    [Daulton2022bopr]_ proposes using this in acquisition optimization.
+    """
+
+    @staticmethod
+    def forward(ctx, X: Tensor) -> Tensor:
+        r"""Round the input tensor element-wise.
+
+        Args:
+            X: The tensor to be rounded.
+
+        Returns:
+            A tensor where each element is rounded to the nearest integer.
+        """
+        return X.round()
+
+
+class OneHotArgmaxSTE(IdentitySTEFunction):
+    r"""Discretize a continuous relaxation of a one-hot encoded categorical.
+
+    This returns a one-hot encoded categorical and uses a straight-through
+    gradient estimator via an identity function.
+
+    [Daulton2022bopr]_ proposes using this in acquisition optimization.
+    """
+
+    @staticmethod
+    def forward(ctx, X: Tensor) -> Tensor:
+        r"""Discretize the input tensor.
+
+        This applies an argmax along the last dimension of the input tensor
+        and one-hot encodes the result.
+
+        Args:
+            X: The tensor to be discretized.
+
+        Returns:
+            A tensor containing the one-hot encoding of the argmax along the
+            last dimension of the input tensor.
+        """
+        return one_hot(X.argmax(dim=-1), num_classes=X.shape[-1]).to(X)
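
A short usage sketch of the new utilities (illustrative only; the tensor values and shapes are arbitrary, and the commented outputs follow from the forward/backward definitions above). Both classes are applied through the standard torch.autograd.Function.apply interface:

import torch
from botorch.utils.rounding import OneHotArgmaxSTE, RoundSTE

# Integer parameters: exact rounding in the forward pass,
# identity (straight-through) gradient in the backward pass.
X = torch.tensor([0.2, 1.7, 2.4], requires_grad=True)
X_int = RoundSTE.apply(X)  # tensor([0., 2., 2.])
X_int.sum().backward()     # X.grad == tensor([1., 1., 1.])

# Categorical parameters: a continuous relaxation over 4 categories
# is snapped to its one-hot argmax; gradients again pass straight through.
Z = torch.rand(5, 4, requires_grad=True)
Z_onehot = OneHotArgmaxSTE.apply(Z)  # rows of Z_onehot are one-hot
Z_onehot.sum().backward()            # Z.grad == torch.ones(5, 4)

Because the backward pass is the identity, gradients with respect to the relaxed inputs are exactly the gradients with respect to the discretized outputs, which is what makes these functions usable inside gradient-based acquisition optimization.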

test/utils/test_rounding.py

Lines changed: 65 additions & 1 deletion

@@ -6,8 +6,20 @@


 import torch
-from botorch.utils.rounding import approximate_round
+from botorch.utils.rounding import (
+    approximate_round,
+    IdentitySTEFunction,
+    OneHotArgmaxSTE,
+    RoundSTE,
+)
 from botorch.utils.testing import BotorchTestCase
+from torch.nn.functional import one_hot
+
+
+class DummySTEFunction(IdentitySTEFunction):
+    @staticmethod
+    def forward(ctx, X):
+        return 2 * X


 class TestApproximateRound(BotorchTestCase):
@@ -25,3 +37,55 @@ def test_approximate_round(self):
         X.requires_grad_(True)
         approximate_round(X).sum().backward()
         self.assertTrue((X.grad.abs() != 0).any())
+
+
+class TestIdentitySTEFunction(BotorchTestCase):
+    def test_identity_ste(self):
+        for dtype in (torch.float, torch.double):
+            X = torch.rand(3, device=self.device, dtype=dtype)
+            with self.assertRaises(NotImplementedError):
+                IdentitySTEFunction.apply(X)
+            X = X.requires_grad_(True)
+            X_out = DummySTEFunction.apply(X)
+            X_out.sum().backward()
+            self.assertTrue(torch.equal(2 * X, X_out))
+            self.assertTrue(torch.equal(X.grad, torch.ones_like(X)))
+
+
+class TestRoundSTE(BotorchTestCase):
+    def test_round_ste(self):
+        for dtype in (torch.float, torch.double):
+            # sample uniformly from the interval [-2.5, 2.5]
+            X = torch.rand(5, 2, device=self.device, dtype=dtype) * 5 - 2.5
+            expected_rounded_X = X.round()
+            rounded_X = RoundSTE.apply(X)
+            # test forward
+            self.assertTrue(torch.equal(expected_rounded_X, rounded_X))
+            # test backward
+            X = X.requires_grad_(True)
+            output = RoundSTE.apply(X)
+            # sample some weights to check that gradients are passed
+            # through as intended
+            w = torch.rand_like(X)
+            (w * output).sum().backward()
+            self.assertTrue(torch.equal(w, X.grad))
+
+
+class TestOneHotArgmaxSTE(BotorchTestCase):
+    def test_one_hot_argmax_ste(self):
+        for dtype in (torch.float, torch.double):
+            X = torch.rand(5, 4, device=self.device, dtype=dtype)
+            expected_discretized_X = one_hot(
+                X.argmax(dim=-1), num_classes=X.shape[-1]
+            ).to(X)
+            discretized_X = OneHotArgmaxSTE.apply(X)
+            # test forward
+            self.assertTrue(torch.equal(expected_discretized_X, discretized_X))
+            # test backward
+            X = X.requires_grad_(True)
+            output = OneHotArgmaxSTE.apply(X)
+            # sample some weights to check that gradients are passed
+            # through as intended
+            w = torch.rand_like(X)
+            (w * output).sum().backward()
+            self.assertTrue(torch.equal(w, X.grad))
