diff --git a/compiler_opt/rl/compilation_runner.py b/compiler_opt/rl/compilation_runner.py index 15a6d32e..91aea13e 100644 --- a/compiler_opt/rl/compilation_runner.py +++ b/compiler_opt/rl/compilation_runner.py @@ -283,7 +283,8 @@ def is_priority_method(cls, method_name: str) -> bool: def __init__(self, clang_path: Optional[str] = None, launcher_path: Optional[str] = None, - moving_average_decay_rate: float = 1): + moving_average_decay_rate: float = 1, + compilation_timeout=None): """Initialization of CompilationRunner class. Args: @@ -294,7 +295,9 @@ def __init__(self, self._clang_path = clang_path self._launcher_path = launcher_path self._moving_average_decay_rate = moving_average_decay_rate - self._compilation_timeout = _COMPILATION_TIMEOUT.value + # Avoid reading the flag during the first interpretation of this module. + self._compilation_timeout = ( + compilation_timeout or _COMPILATION_TIMEOUT.value) self._cancellation_manager = WorkerCancellationManager() # re-allow the cancellation manager accept work. @@ -319,8 +322,6 @@ def collect_data( module_spec: a ModuleSpec. tf_policy_path: path to the tensorflow policy. reward_stat: reward stat of this module, None if unknown. - cancellation_token: a CancellationToken through which workers may be - signaled early termination Returns: A CompilationResult. In particular: