Skip to content

Commit 3a5b206

Browse files
Move get_rewards to BlackboxEvaluator base class
Previously this function was implemented by subclasses. Move it to the base class given everything we have planned so far (TraceBlackboxEvaluator) will use the same implementation. Reviewers: mtrofin Reviewed By: mtrofin Pull Request: #418
1 parent 9915a6d commit 3a5b206

File tree

2 files changed

+11
-16
lines changed

2 files changed

+11
-16
lines changed

compiler_opt/es/blackbox_evaluator.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,18 @@ def get_results(
4545
def set_baseline(self) -> None:
4646
raise NotImplementedError()
4747

48-
@abc.abstractmethod
4948
def get_rewards(
5049
self, results: List[concurrent.futures.Future]) -> List[Optional[float]]:
51-
raise NotImplementedError()
50+
rewards = [None] * len(results)
51+
52+
for i in range(len(results)):
53+
if not results[i].exception():
54+
rewards[i] = results[i].result()
55+
else:
56+
logging.info('Error retrieving result from future: %s',
57+
str(results[i].exception()))
58+
59+
return rewards
5260

5361

5462
@gin.configurable
@@ -95,16 +103,3 @@ def get_results(
95103

96104
def set_baseline(self) -> None:
97105
pass
98-
99-
def get_rewards(
100-
self, results: List[concurrent.futures.Future]) -> List[Optional[float]]:
101-
rewards = [None] * len(results)
102-
103-
for i in range(len(results)):
104-
if not results[i].exception():
105-
rewards[i] = results[i].result()
106-
else:
107-
logging.info('Error retrieving result from future: %s',
108-
str(results[i].exception()))
109-
110-
return rewards

compiler_opt/es/blackbox_evaluator_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def test_sampling_get_results(self):
4141
self.assertSequenceAlmostEqual([result.result() for result in results],
4242
[1.0, 1.0, 1.0])
4343

44-
def test_sampling_get_rewards(self):
44+
def test_get_rewards(self):
4545
f1 = concurrent.futures.Future()
4646
f1.set_exception(None)
4747
f2 = concurrent.futures.Future()

0 commit comments

Comments
 (0)