Skip to content

Commit 9ab099e

Browse files
committed
Add test
1 parent 113c211 commit 9ab099e

File tree

1 file changed

+62
-0
lines changed

1 file changed

+62
-0
lines changed

tests/test_function_evaluator.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import os
2+
import sys
23

34
import numpy as np
45
import matplotlib.pyplot as plt
@@ -25,6 +26,17 @@ def eval_func(input_params, output_params):
2526
plt.savefig("fig.png")
2627

2728

29+
def eval_func_logs(input_params, output_params):
30+
"""Evaluation function used for testing"""
31+
x0 = input_params["x0"]
32+
x1 = input_params["x1"]
33+
result = -(x0 + 10 * np.cos(x0)) * (x1 + 5 * np.cos(x1))
34+
output_params["f"] = result
35+
# write something to stdout and stderr
36+
print("This is a message to stdout.")
37+
print("This is a message to stderr.", file=sys.stderr)
38+
39+
2840
def test_function_evaluator():
2941
"""Test that an exploration runs successfully with a function evaluator."""
3042

@@ -93,5 +105,55 @@ def test_function_evaluator():
93105
diags.get_evaluation_dir_path(trial_index)
94106

95107

108+
def test_function_evaluator_with_logs():
109+
"""Test a function evaluator with redirected stdout and stderr."""
110+
111+
# Define variables and objectives.
112+
var1 = VaryingParameter("x0", -50.0, 5.0)
113+
var2 = VaryingParameter("x1", -5.0, 15.0)
114+
obj = Objective("f", minimize=False)
115+
116+
# Create generator.
117+
gen = RandomSamplingGenerator(
118+
varying_parameters=[var1, var2],
119+
objectives=[obj],
120+
)
121+
122+
# Create function evaluator.
123+
ev = FunctionEvaluator(
124+
function=eval_func_logs,
125+
redirect_logs_to_file=True,
126+
)
127+
128+
# Create exploration.
129+
exploration = Exploration(
130+
generator=gen,
131+
evaluator=ev,
132+
max_evals=10,
133+
sim_workers=2,
134+
exploration_dir_path="./tests_output/test_function_evaluator_logs",
135+
)
136+
137+
# Run exploration.
138+
exploration.run()
139+
140+
# Get diagnostics.
141+
diags = ExplorationDiagnostics(exploration)
142+
143+
# Check that the logs were redirected if specified.
144+
for trial_index in diags.history.trial_index:
145+
trial_dir = diags.get_evaluation_dir_path(trial_index)
146+
assert os.path.exists(os.path.join(trial_dir, "log.out"))
147+
assert os.path.exists(os.path.join(trial_dir, "log.err"))
148+
# Check contents of log files are as expected
149+
with open(os.path.join(trial_dir, "log.out"), "r") as f:
150+
log_out_content = f.read()
151+
assert "This is a message to stdout." in log_out_content
152+
with open(os.path.join(trial_dir, "log.err"), "r") as f:
153+
log_err_content = f.read()
154+
assert "This is a message to stderr." in log_err_content
155+
156+
96157
if __name__ == "__main__":
97158
test_function_evaluator()
159+
test_function_evaluator_with_logs()

0 commit comments

Comments
 (0)