4747from __future__ import annotations
4848
4949import math
50- from typing import List , Optional , Tuple
50+ from typing import List , Optional , Tuple , Union
5151
5252import torch
53+ from botorch .exceptions .errors import InputDataError
5354from botorch .test_functions .base import BaseTestProblem , ConstrainedBaseTestProblem
5455from botorch .test_functions .utils import round_nearest
5556from torch import Tensor
@@ -64,13 +65,15 @@ class SyntheticTestFunction(BaseTestProblem):
6465
6566 def __init__ (
6667 self ,
67- noise_std : Optional [ float ] = None ,
68+ noise_std : Union [ None , float , List [ float ] ] = None ,
6869 negate : bool = False ,
6970 bounds : Optional [List [Tuple [float , float ]]] = None ,
7071 ) -> None :
7172 r"""
7273 Args:
73- noise_std: Standard deviation of the observation noise.
74+ noise_std: Standard deviation of the observation noise. If a list is
75+ provided, specifies separate noise standard deviations for each
76+ objective in a multiobjective problem.
7477 negate: If True, negate the function.
7578 bounds: Custom bounds for the function specified as (lower, upper) pairs.
7679 """
@@ -802,7 +805,61 @@ def evaluate_true(self, X: Tensor) -> Tensor:
802805# ------------ Constrained synthetic test functions ----------- #
803806
804807
805- class ConstrainedGramacy (ConstrainedBaseTestProblem , SyntheticTestFunction ):
808+ class ConstrainedSyntheticTestFunction (
809+ ConstrainedBaseTestProblem , SyntheticTestFunction
810+ ):
811+ r"""Base class for constrained synthetic test functions."""
812+
813+ def __init__ (
814+ self ,
815+ noise_std : Union [None , float , List [float ]] = None ,
816+ constraint_noise_std : Union [None , float , List [float ]] = None ,
817+ negate : bool = False ,
818+ bounds : Optional [List [Tuple [float , float ]]] = None ,
819+ ) -> None :
820+ r"""
821+ Args:
822+ noise_std: Standard deviation of the observation noise. If a list is
823+ provided, specifies separate noise standard deviations for each
824+ objective in a multiobjective problem.
825+ constraint_noise_std: Standard deviation of the constraint noise.
826+ If a list is provided, specifies separate noise standard
827+ deviations for each constraint.
828+ negate: If True, negate the function.
829+ bounds: Custom bounds for the function specified as (lower, upper) pairs.
830+ """
831+ self .constraint_noise_std = self ._validate_constraint_noise (
832+ constraint_noise_std
833+ )
834+ SyntheticTestFunction .__init__ (
835+ self , noise_std = noise_std , negate = negate , bounds = bounds
836+ )
837+
838+ def _validate_constraint_noise (
839+ self , constraint_noise_std
840+ ) -> Union [None , float , List [float ]]:
841+ """
842+ Validates that constraint_noise_std has length equal to
843+ the number of constraints, if given as a list
844+
845+ Args:
846+ constraint_noise_std: Standard deviation of the constraint noise.
847+ If a list is provided, specifies separate noise standard
848+ deviations for each constraint.
849+ """
850+ if (
851+ isinstance (constraint_noise_std , list )
852+ and len (constraint_noise_std ) != self .num_constraints
853+ ):
854+ raise InputDataError (
855+ "If specified as a list, length of constraint_noise_std "
856+ f"({ len (constraint_noise_std )} ) must match the "
857+ f"number of constraints ({ self .num_constraints } )"
858+ )
859+ return constraint_noise_std
860+
861+
862+ class ConstrainedGramacy (ConstrainedSyntheticTestFunction ):
806863 r"""Constrained Gramacy test function.
807864
808865 This problem comes from [Gramacy2016]_. The problem is defined
@@ -835,31 +892,77 @@ def evaluate_slack_true(self, X: Tensor) -> Tensor:
835892 return torch .cat ([- c1 , - c2 ], dim = - 1 )
836893
837894
838- class ConstrainedHartmann (Hartmann , ConstrainedBaseTestProblem ):
895+ class ConstrainedHartmann (Hartmann , ConstrainedSyntheticTestFunction ):
839896 r"""Constrained Hartmann test function.
840897
841898 This is a constrained version of the standard Hartmann test function that
842899 uses `||x||_2 <= 1` as the constraint. This problem comes from [Letham2019]_.
843900 """
844901 num_constraints = 1
845902
903+ def __init__ (
904+ self ,
905+ dim : int = 6 ,
906+ noise_std : Union [None , float ] = None ,
907+ constraint_noise_std : Union [None , float , List [float ]] = None ,
908+ negate : bool = False ,
909+ bounds : Optional [List [Tuple [float , float ]]] = None ,
910+ ) -> None :
911+ r"""
912+ Args:
913+ dim: The (input) dimension.
914+ noise_std: Standard deviation of the observation noise.
915+ constraint_noise_std: Standard deviation of the constraint noise.
916+ If a list is provided, specifies separate noise standard
917+ deviations for each constraint.
918+ negate: If True, negate the function.
919+ bounds: Custom bounds for the function specified as (lower, upper) pairs.
920+ """
921+ self ._validate_constraint_noise (constraint_noise_std )
922+ Hartmann .__init__ (
923+ self , dim = dim , noise_std = noise_std , negate = negate , bounds = bounds
924+ )
925+
846926 def evaluate_slack_true (self , X : Tensor ) -> Tensor :
847927 return - X .norm (dim = - 1 , keepdim = True ) + 1
848928
849929
850- class ConstrainedHartmannSmooth (Hartmann , ConstrainedBaseTestProblem ):
930+ class ConstrainedHartmannSmooth (Hartmann , ConstrainedSyntheticTestFunction ):
851931 r"""Smooth constrained Hartmann test function.
852932
853933 This is a constrained version of the standard Hartmann test function that
854934 uses `||x||_2^2 <= 1` as the constraint to obtain smoother constraint slack.
855935 """
856936 num_constraints = 1
857937
938+ def __init__ (
939+ self ,
940+ dim : int = 6 ,
941+ noise_std : Union [None , float ] = None ,
942+ constraint_noise_std : Union [None , float , List [float ]] = None ,
943+ negate : bool = False ,
944+ bounds : Optional [List [Tuple [float , float ]]] = None ,
945+ ) -> None :
946+ r"""
947+ Args:
948+ dim: The (input) dimension.
949+ noise_std: Standard deviation of the observation noise.
950+ constraint_noise_std: Standard deviation of the constraint noise.
951+ If a list is provided, specifies separate noise standard
952+ deviations for each constraint.
953+ negate: If True, negate the function.
954+ bounds: Custom bounds for the function specified as (lower, upper) pairs.
955+ """
956+ self ._validate_constraint_noise (constraint_noise_std )
957+ Hartmann .__init__ (
958+ self , dim = dim , noise_std = noise_std , negate = negate , bounds = bounds
959+ )
960+
858961 def evaluate_slack_true (self , X : Tensor ) -> Tensor :
859962 return - X .pow (2 ).sum (dim = - 1 , keepdim = True ) + 1
860963
861964
862- class PressureVessel (SyntheticTestFunction , ConstrainedBaseTestProblem ):
965+ class PressureVessel (ConstrainedSyntheticTestFunction ):
863966 r"""Pressure vessel design problem with constraints.
864967
865968 The four-dimensional pressure vessel design problem with four black-box
@@ -894,7 +997,7 @@ def evaluate_slack_true(self, X: Tensor) -> Tensor:
894997 )
895998
896999
897- class WeldedBeamSO (SyntheticTestFunction , ConstrainedBaseTestProblem ):
1000+ class WeldedBeamSO (ConstrainedSyntheticTestFunction ):
8981001 r"""Welded beam design problem with constraints (single-outcome).
8991002
9001003 The four-dimensional welded beam design proble problem with six
@@ -950,7 +1053,7 @@ def evaluate_slack_true(self, X: Tensor) -> Tensor:
9501053 return - torch .stack ([g1 , g2 , g3 , g4 , g5 , g6 ], dim = - 1 )
9511054
9521055
953- class TensionCompressionString (SyntheticTestFunction , ConstrainedBaseTestProblem ):
1056+ class TensionCompressionString (ConstrainedSyntheticTestFunction ):
9541057 r"""Tension compression string optimization problem with constraints.
9551058
9561059 The three-dimensional tension compression string optimization problem with
@@ -981,7 +1084,7 @@ def evaluate_slack_true(self, X: Tensor) -> Tensor:
9811084 return - constraints .clamp_max (100 )
9821085
9831086
984- class SpeedReducer (SyntheticTestFunction , ConstrainedBaseTestProblem ):
1087+ class SpeedReducer (ConstrainedSyntheticTestFunction ):
9851088 r"""Speed Reducer design problem with constraints.
9861089
9871090 The seven-dimensional speed reducer design problem with eleven black-box
0 commit comments