Skip to content

Commit 113c71e

Browse files
committed
fix pytype
1 parent 468adc3 commit 113c71e

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

compiler_opt/rl/compilation_runner.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -230,11 +230,12 @@ class CompilationRunnerStub(metaclass=abc.ABCMeta):
230230
"""The interface of a stub to CompilationRunner, for type checkers."""
231231

232232
@abc.abstractmethod
233-
def collect_results(self,
234-
module_spec: corpus.ModuleSpec,
235-
tf_policy_path: str,
236-
collect_default_result: bool,
237-
reward_only: bool = False) -> Tuple[Dict, Dict]:
233+
def collect_results(
234+
self,
235+
module_spec: corpus.ModuleSpec,
236+
tf_policy_path: str,
237+
collect_default_result: bool,
238+
reward_only: bool = False) -> Tuple[Optional[Dict], Optional[Dict]]:
238239
raise NotImplementedError()
239240

240241
@abc.abstractmethod
@@ -290,11 +291,12 @@ def get_rewards(result: Dict) -> List[float]:
290291
return []
291292
return [v[1] for v in result.values()]
292293

293-
def collect_results(self,
294-
module_spec: corpus.ModuleSpec,
295-
tf_policy_path: str,
296-
collect_default_result: bool,
297-
reward_only: bool = False) -> Tuple[Dict, Dict]:
294+
def collect_results(
295+
self,
296+
module_spec: corpus.ModuleSpec,
297+
tf_policy_path: str,
298+
collect_default_result: bool,
299+
reward_only: bool = False) -> Tuple[Optional[Dict], Optional[Dict]]:
298300
"""Collect data for the given IR file and policy.
299301
300302
Args:

0 commit comments

Comments
 (0)