@@ -39,8 +39,11 @@ def infer_reference_point(
3939) -> Tensor :
4040 r"""Get reference point for hypervolume computations.
4141
42- This sets the reference point to be `ref_point = nadir - 0.1 * range`
43- when there is no pareto_Y that is better than the reference point.
42+ This sets the reference point to be `ref_point = nadir - scale * range`
43+ when there is no `pareto_Y` that is better than `max_ref_point`.
44+ If there's `pareto_Y` better than `max_ref_point`, the reference point
45+ will be set to `max_ref_point - scale * range` if `scale_max_ref_point`
46+ is true and to `max_ref_point` otherwise.
4447
4548 [Ishibuchi2011]_ find 0.1 to be a robust multiplier for scaling the
4649 nadir point.
@@ -50,6 +53,9 @@ def infer_reference_point(
5053 Args:
5154 pareto_Y: A `n x m`-dim tensor of Pareto-optimal points.
5255 max_ref_point: A `m` dim tensor indicating the maximum reference point.
56+ Some elements can be NaN, except when `pareto_Y` is empty,
57+ in which case these dimensions will be treated as if no
58+ `max_ref_point` was provided and set to `nadir - scale * range`.
5359 scale: A multiplier used to scale back the reference point based on the
5460 range of each objective.
5561 scale_max_ref_point: A boolean indicating whether to apply scaling to
@@ -58,20 +64,28 @@ def infer_reference_point(
5864 Returns:
5965 A `m`-dim tensor containing the reference point.
6066 """
61-
6267 if pareto_Y .shape [0 ] == 0 :
6368 if max_ref_point is None :
6469 raise BotorchError ("Empty pareto set and no max ref point provided" )
70+ if max_ref_point .isnan ().any ():
71+ raise BotorchError ("Empty pareto set and max ref point includes NaN." )
6572 if scale_max_ref_point :
6673 return max_ref_point - scale * max_ref_point .abs ()
6774 return max_ref_point
6875 if max_ref_point is not None :
69- better_than_ref = (pareto_Y > max_ref_point ).all (dim = - 1 )
76+ non_nan_idx = ~ max_ref_point .isnan ()
77+ # Count all points exceeding non-NaN reference point as being better.
78+ better_than_ref = (pareto_Y [:, non_nan_idx ] > max_ref_point [non_nan_idx ]).all (
79+ dim = - 1
80+ )
7081 else :
71- better_than_ref = torch .full (
72- pareto_Y .shape [: 1 ], 1 , dtype = bool , device = pareto_Y .device
82+ non_nan_idx = torch .ones (
83+ pareto_Y .shape [- 1 ], dtype = torch . bool , device = pareto_Y .device
7384 )
74- if max_ref_point is not None and better_than_ref .any ():
85+ better_than_ref = torch .ones (
86+ pareto_Y .shape [:1 ], dtype = torch .bool , device = pareto_Y .device
87+ )
88+ if max_ref_point is not None and better_than_ref .any () and non_nan_idx .all ():
7589 Y_range = pareto_Y [better_than_ref ].max (dim = 0 ).values - max_ref_point
7690 if scale_max_ref_point :
7791 return max_ref_point - scale * Y_range
@@ -80,17 +94,28 @@ def infer_reference_point(
8094 # no points better than max_ref_point and only a single observation
8195 # subtract MIN_Y_RANGE to handle the case that pareto_Y is a singleton
8296 # with objective value of 0.
83- return (pareto_Y - scale * pareto_Y .abs ().clamp_min (MIN_Y_RANGE )).view (- 1 )
84- # no points better than max_ref_point and multiple observations
85- # make sure that each dimension of the nadir point is no greater than
86- # the max_ref_point
87- nadir = pareto_Y .min (dim = 0 ).values
88- if max_ref_point is not None :
89- nadir = torch .min (nadir , max_ref_point )
90- ideal = pareto_Y .max (dim = 0 ).values
91- # handle case where all values for one objective are the same
92- Y_range = (ideal - nadir ).clamp_min (MIN_Y_RANGE )
93- return nadir - scale * Y_range
97+ Y_range = pareto_Y .abs ().clamp_min (MIN_Y_RANGE ).view (- 1 )
98+ ref_point = pareto_Y .view (- 1 ) - scale * Y_range
99+ else :
100+ # no points better than max_ref_point and multiple observations
101+ # make sure that each dimension of the nadir point is no greater than
102+ # the max_ref_point
103+ nadir = pareto_Y .min (dim = 0 ).values
104+ if max_ref_point is not None :
105+ nadir [non_nan_idx ] = torch .min (
106+ nadir [non_nan_idx ], max_ref_point [non_nan_idx ]
107+ )
108+ ideal = pareto_Y .max (dim = 0 ).values
109+ # handle case where all values for one objective are the same
110+ Y_range = (ideal - nadir ).clamp_min (MIN_Y_RANGE )
111+ ref_point = nadir - scale * Y_range
112+ # Set not-nan indices - if any - to max_ref_point.
113+ if non_nan_idx .any () and not non_nan_idx .all () and better_than_ref .any ():
114+ if scale_max_ref_point :
115+ ref_point [non_nan_idx ] = (max_ref_point - scale * Y_range )[non_nan_idx ]
116+ else :
117+ ref_point [non_nan_idx ] = max_ref_point [non_nan_idx ]
118+ return ref_point
94119
95120
96121class Hypervolume :
0 commit comments