@@ -82,7 +82,14 @@ def __init__(self, config_path=None, config=None):
8282
8383 for key in FUNCTION_APPROXIMATION_DEFAULTS :
8484 if key not in self .config :
85- self .config [key ] = FUNCTION_APPROXIMATION_DEFAULTS [key ]
85+ if key == "observation_space_args" and "config_space" in self .config :
86+ obs_length = 1 + len (self .config ["config_space" ]) * 4
87+ self .config [key ] = [
88+ np .array ([- np .inf for _ in range (obs_length )]),
89+ np .array ([np .inf for _ in range (obs_length )]),
90+ ]
91+ else :
92+ self .config [key ] = FUNCTION_APPROXIMATION_DEFAULTS [key ]
8693
8794 def get_environment (self ):
8895 """Return Function Approximation env with current configuration.
@@ -197,6 +204,10 @@ def get_benchmark(self, dimension=None, seed=0):
197204 "Slope (dimension 1)" ,
198205 "Action" ,
199206 ]
207+ self .config .observation_space_args = [
208+ np .array ([- np .inf for _ in range (4 )]),
209+ np .array ([np .inf for _ in range (4 )]),
210+ ]
200211 if dimension == 2 :
201212 self .config .instance_set_path = "sigmoid_2D3M_train.csv"
202213 self .config .test_set_path = "sigmoid_2D3M_test.csv"
@@ -220,6 +231,10 @@ def get_benchmark(self, dimension=None, seed=0):
220231 "Action dim 1" ,
221232 "Action dim 2" ,
222233 ]
234+ self .config .observation_space_args = [
235+ np .array ([- np .inf for _ in range (7 )]),
236+ np .array ([np .inf for _ in range (7 )]),
237+ ]
223238 if dimension == 3 :
224239 self .config .instance_set_path = "sigmoid_3D3M_train.csv"
225240 self .config .test_set_path = "sigmoid_3D3M_test.csv"
@@ -250,6 +265,10 @@ def get_benchmark(self, dimension=None, seed=0):
250265 "Action 2" ,
251266 "Action 3" ,
252267 ]
268+ self .config .observation_space_args = [
269+ np .array ([- np .inf for _ in range (10 )]),
270+ np .array ([np .inf for _ in range (10 )]),
271+ ]
253272 if dimension == 5 :
254273 self .config .instance_set_path = "sigmoid_5D3M_train.csv"
255274 self .config .test_set_path = "sigmoid_5D3M_test.csv"
@@ -294,6 +313,10 @@ def get_benchmark(self, dimension=None, seed=0):
294313 "Action 4" ,
295314 "Action 5" ,
296315 ]
316+ self .config .observation_space_args = [
317+ np .array ([- np .inf for _ in range (16 )]),
318+ np .array ([np .inf for _ in range (16 )]),
319+ ]
297320 self .config .seed = seed
298321 self .read_instance_set ()
299322 self .read_instance_set (test = True )
0 commit comments