Skip to content

Commit 4a85790

Browse files
committed
TST: Test the unknown GPR optimizer statement
Test the unknown GPR optimizer statement. Declare the error message as a global variable so that it can be tested for exactness.
1 parent 3f684e1 commit 4a85790

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

src/nifreeze/model/gpr.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,9 @@
6262
SUPPORTED_OPTIMIZERS = set(CONFIGURABLE_OPTIONS.keys()) | {"fmin_l_bfgs_b"}
6363
"""A set of supported optimizers (automatically created)."""
6464

65+
UNKNOWN_OPTIMIZER_ERROR_MSG = "Unknown optimizer {optimizer}."
66+
"""Unknown optimizer error message."""
67+
6568

6669
class DiffusionGPR(GaussianProcessRegressor):
6770
r"""
@@ -252,7 +255,7 @@ def _constrained_optimization(
252255
if callable(self.optimizer):
253256
return self.optimizer(obj_func, initial_theta, bounds=bounds)
254257

255-
raise ValueError(f"Unknown optimizer {self.optimizer}.")
258+
raise ValueError(UNKNOWN_OPTIMIZER_ERROR_MSG.format(optimizer=self.optimizer))
256259

257260

258261
class ExponentialKriging(Kernel):

test/test_gpr.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,3 +292,22 @@ def test_kernel(repodata, covariance):
292292

293293
K_predict = kernel(bvecs, bvecs[10:14, ...])
294294
assert K_predict.shape == (K.shape[0], 4)
295+
296+
297+
def test_unknown_optimizer():
298+
# Create a GPR with an optimizer string that is not supported
299+
optimizer = "bad-optimizer"
300+
gp = gpr.DiffusionGPR(optimizer=optimizer) # type: ignore
301+
302+
# A minimal objective function (will not be called by this test path)
303+
def obj_func(theta, eval_gradient):
304+
return 0.0
305+
306+
initial_theta = np.array([0.1])
307+
bounds = [(0.0, 1.0)]
308+
309+
# Expect the specific ValueError message including the optimizer name
310+
with pytest.raises(
311+
ValueError, match=gpr.UNKNOWN_OPTIMIZER_ERROR_MSG.format(optimizer=optimizer)
312+
):
313+
gp._constrained_optimization(obj_func, initial_theta, bounds)

0 commit comments

Comments
 (0)