1- # Copyright 2020-2024 The Emukit Authors. All Rights Reserved.
1+ # Copyright 2020-2026 The Emukit Authors. All Rights Reserved.
22# SPDX-License-Identifier: Apache-2.0
33
44# Copyright 2018-2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
55# SPDX-License-Identifier: Apache-2.0
66
77
8- # Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
9- # SPDX-License-Identifier: Apache-2.0
10-
11-
128import numpy as np
13-
14- try :
15- from sobol_seq import i4_sobol_generate
16- except ImportError :
17- raise ImportError ("sobol_seq needs to be installed in order to use sobol design" )
9+ from scipy .stats import qmc
1810
1911from .. import ParameterSpace
2012from .base import InitialDesignBase
2315class SobolDesign (InitialDesignBase ):
2416 """
2517 Sobol experiment design.
26- Based on sobol_seq implementation. For further reference see https://github.com/naught101/sobol_seq
18+ Based on scipy implementation. For further reference see
19+ https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.qmc.Sobol.html
2720 """
2821
2922 def __init__ (self , parameter_space : ParameterSpace ) -> None :
@@ -40,12 +33,14 @@ def get_samples(self, point_count: int) -> np.ndarray:
4033 :return: A numpy array of generated samples, shape (point_count x space_dim)
4134 """
4235 bounds = self .parameter_space .get_bounds ()
43- lower_bound = np . asarray ( bounds )[:, 0 ]. reshape ( 1 , len (bounds ) )
44- upper_bound = np . asarray ( bounds )[:, 1 ]. reshape ( 1 , len ( bounds ))
45- diff = upper_bound - lower_bound
36+ d = len (bounds )
37+ lower_bounds = [ x [ 0 ] for x in bounds ]
38+ upper_bounds = [ x [ 1 ] for x in bounds ]
4639
47- X_design = np .dot (i4_sobol_generate (len (bounds ), point_count ), np .diag (diff [0 , :])) + lower_bound
40+ sampler = qmc .Sobol (d )
41+ samples = sampler .random (n = point_count )
42+ samples = qmc .scale (samples , lower_bounds , upper_bounds )
4843
49- samples = self .parameter_space .round (X_design )
44+ X_design = self .parameter_space .round (samples )
5045
51- return samples
46+ return X_design
0 commit comments