Skip to content

Commit 7370224

Browse files
esantorellafacebook-github-bot
authored andcommitted
More informative errors for wrong index dtype (#1345)
Summary: ## Motivation This is fixing a very minor issue and I've probably already spent too much time on it. Unless anyone feels really strongly about this, I'd prefer to either have this quickly either accepted or rejected rather than spend a while iterating on revisions. In the past it was possible to use indexers with dtypes that torch does not accept as indexers via `equality_constraints` and `inequality_constraints`. This was never really intended behavior and stopped being supported in #1341 (discussed in #1225 ) . This PR makes errors more informative if someone does try to use the wrong dtypes, since the existing error message did not make clear where the error came from. I aslo refactored a test in test_initalizers.py. ### Have you read the [Contributing Guidelines on pull requests](https://github.com/pytorch/botorch/blob/main/CONTRIBUTING.md#pull-requests)? Yes Pull Request resolved: #1345 Test Plan: Unit tests for errors raised ## Related PRs #1341 Reviewed By: Balandat Differential Revision: D38627695 Pulled By: esantorella fbshipit-source-id: e9e4e917b79f81a36d74b79ff3b0f710667283cb
1 parent 324b7e2 commit 7370224

File tree

2 files changed

+100
-3
lines changed

2 files changed

+100
-3
lines changed

botorch/utils/sampling.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -888,12 +888,28 @@ def get_polytope_samples(
888888
"""
889889
# create tensors representing linear inequality constraints
890890
# of the form Ax >= b.
891+
# TODO: remove this error handling functionality in a few releases.
892+
# Context: BoTorch inadvertently supported indices with unusual dtypes.
893+
# This is now not supported.
894+
index_dtype_error = (
895+
"Normalizing {var_name} failed. Check that the first "
896+
"element of {var_name} is the correct dtype following "
897+
"the previous IndexError."
898+
)
891899
if inequality_constraints:
892900
# normalize_linear_constraints is called to solve this issue:
893901
# https://github.com/pytorch/botorch/issues/1225
902+
try:
903+
# non-standard dtypes used to be supported for indices in constraints;
904+
# this is no longer true
905+
constraints = normalize_linear_constraints(bounds, inequality_constraints)
906+
except IndexError as e:
907+
msg = index_dtype_error.format(var_name="`inequality_constraints`")
908+
raise ValueError(msg) from e
909+
894910
A, b = sparse_to_dense_constraints(
895911
d=bounds.shape[-1],
896-
constraints=normalize_linear_constraints(bounds, inequality_constraints),
912+
constraints=constraints,
897913
)
898914
# Note the inequality constraints are of the form Ax >= b,
899915
# but PolytopeSampler expects inequality constraints of the
@@ -902,11 +918,18 @@ def get_polytope_samples(
902918
else:
903919
dense_inequality_constraints = None
904920
if equality_constraints:
921+
try:
922+
# non-standard dtypes used to be supported for indices in constraints;
923+
# this is no longer true
924+
constraints = normalize_linear_constraints(bounds, equality_constraints)
925+
except IndexError as e:
926+
msg = index_dtype_error.format(var_name="`equality_constraints`")
927+
raise ValueError(msg) from e
928+
905929
# normalize_linear_constraints is called to solve this issue:
906930
# https://github.com/pytorch/botorch/issues/1225
907931
dense_equality_constraints = sparse_to_dense_constraints(
908-
d=bounds.shape[-1],
909-
constraints=normalize_linear_constraints(bounds, equality_constraints),
932+
d=bounds.shape[-1], constraints=constraints
910933
)
911934
else:
912935
dense_equality_constraints = None

test/utils/test_sampling.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,22 @@ def test_normalize_linear_constraints(self):
312312
expected_rhs = 0.5
313313
self.assertAlmostEqual(new_constraints[0][-1], expected_rhs)
314314

315+
def test_normalize_linear_constraints_wrong_dtype(self):
316+
for dtype in (torch.float, torch.double):
317+
with self.subTest(dtype=dtype):
318+
tkwargs = {"device": self.device, "dtype": dtype}
319+
constraints = [
320+
(
321+
torch.ones(3, dtype=torch.float, device=self.device),
322+
torch.ones(3, **tkwargs),
323+
1.0,
324+
)
325+
]
326+
bounds = torch.zeros(2, 4, **tkwargs)
327+
msg = "tensors used as indices must be long, byte or bool tensors"
328+
with self.assertRaises(IndexError, msg=msg):
329+
normalize_linear_constraints(bounds, constraints)
330+
315331
def test_find_interior_point(self):
316332
# basic problem: 1 <= x_1 <= 2, 2 <= x_2 <= 3
317333
A = np.concatenate([np.eye(2), -np.eye(2)], axis=0)
@@ -333,6 +349,64 @@ def test_find_interior_point(self):
333349
x = find_interior_point(A=A, b=b)
334350
self.assertAlmostEqual(x.item(), 5.0, places=4)
335351

352+
def test_get_polytope_samples_wrong_inequality_constraints_dtype(self):
353+
for dtype in (torch.float, torch.double):
354+
with self.subTest(dtype=dtype):
355+
tkwargs = {"device": self.device, "dtype": dtype}
356+
bounds = torch.zeros(2, 4, **tkwargs)
357+
inequality_constraints = [
358+
(
359+
torch.tensor([3], dtype=torch.float, device=self.device),
360+
torch.tensor([-4], **tkwargs),
361+
-3,
362+
)
363+
]
364+
365+
msg = (
366+
"Normalizing `inequality_constraints` failed. Check that the first "
367+
"element of `inequality_constraints` is the correct dtype following"
368+
" the previous IndexError."
369+
)
370+
msg_orig = "tensors used as indices must be long, byte or bool tensors"
371+
372+
with self.assertRaisesRegex(ValueError, msg), self.assertRaisesRegex(
373+
IndexError, msg_orig
374+
):
375+
get_polytope_samples(
376+
n=5,
377+
bounds=bounds,
378+
inequality_constraints=inequality_constraints,
379+
)
380+
381+
def test_get_polytope_samples_wrong_equality_constraints_dtype(self):
382+
for dtype in (torch.float, torch.double):
383+
with self.subTest(dtype=dtype):
384+
tkwargs = {"device": self.device, "dtype": dtype}
385+
bounds = torch.zeros(2, 4, **tkwargs)
386+
387+
equality_constraints = [
388+
(
389+
torch.tensor([0], dtype=torch.float, device=self.device),
390+
torch.tensor([1], **tkwargs),
391+
0.5,
392+
)
393+
]
394+
msg = (
395+
"Normalizing `equality_constraints` failed. Check that the first "
396+
"element of `equality_constraints` is the correct dtype following "
397+
"the previous IndexError."
398+
)
399+
msg_orig = "tensors used as indices must be long, byte or bool tensors"
400+
401+
with self.assertRaisesRegex(ValueError, msg), self.assertRaisesRegex(
402+
IndexError, msg_orig
403+
):
404+
get_polytope_samples(
405+
n=5,
406+
bounds=bounds,
407+
equality_constraints=equality_constraints,
408+
)
409+
336410
def test_get_polytope_samples(self):
337411
tkwargs = {"device": self.device}
338412
for dtype in (torch.float, torch.double):

0 commit comments

Comments
 (0)