Skip to content

Commit fec0502

Browse files
stellaraccident authored and giacs-epic committed
[iree.build] Wire up out of process concurrency. (iree-org#19291)
* Introduces an explicit thunk creation stage which gives a way to create a fully remotable object.
* Reworks process concurrency to occupy a host thread in addition to a sub-process, which keeps the task concurrency accounting simple and makes errors propagate more easily.
* Adds a test action for invoking a thunk out of process.
* This is the boilerplate required while implementing a turbine AOT export action.

Signed-off-by: Stella Laurenzo <[email protected]>
Signed-off-by: Giacomo Serafini <[email protected]>
1 parent 4155e68 commit fec0502

File tree

5 files changed

+146
-13
lines changed

5 files changed

+146
-13
lines changed

compiler/bindings/python/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,7 @@ SOURCES
268268
net_actions.py
269269
onnx_actions.py
270270
target_machine.py
271+
test_actions.py
271272
)
272273

273274
add_mlir_python_modules(IREECompilerBuildPythonModules

compiler/bindings/python/iree/build/executor.py

Lines changed: 46 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -276,8 +276,17 @@ def __str__(self) -> str:
276276
return self.value
277277

278278

279-
class BuildAction(BuildDependency, abc.ABC):
280-
"""An action that must be carried out."""
279+
class BuildAction(BuildDependency):
280+
"""An action that must be carried out.
281+
282+
This class is designed to be subclassed by concrete actions. In-process
283+
only actions should override `_invoke`, whereas those that can be executed
284+
out-of-process must override `_remotable_thunk`.
285+
286+
Note that even actions that are marked for `PROCESS` concurrency will
287+
run on a dedicated thread within the host process. Only the `_remotable_thunk`
288+
result will be scheduled out of process.
289+
"""
281290

282291
def __init__(
283292
self,
@@ -289,20 +298,43 @@ def __init__(
289298
):
290299
super().__init__(executor=executor, deps=deps)
291300
self.desc = desc
292-
self.concurrnecy = concurrency
301+
self.concurrency = concurrency
293302

294303
def __str__(self):
295304
return self.desc
296305

297306
def __repr__(self):
298307
return f"Action[{type(self).__name__}]('{self.desc}')"
299308

300-
def invoke(self):
301-
self._invoke()
309+
def invoke(self, scheduler: "Scheduler"):
    """Run this action in the execution context selected by `self.concurrency`.

    Invoke is run within whatever in-process execution context was requested:

    * On the scheduler thread for NONE.
    * On a worker thread for THREAD or PROCESS.

    Args:
        scheduler: The scheduler driving the build. Only its
            `process_pool_executor` attribute is used, and only for
            `PROCESS` concurrency.
    """
    if self.concurrency == ActionConcurrency.PROCESS:
        # For PROCESS concurrency, we have to create a compatible invocation
        # thunk, schedule that on the process pool and wait for it.
        thunk = self._remotable_thunk()
        fut = scheduler.process_pool_executor.submit(thunk)
        # Waiting here keeps the calling thread occupied for the duration
        # of the out-of-process work.
        # NOTE(review): presumably a concurrent.futures future, in which
        # case result() also re-raises any exception from the worker.
        fut.result()
    else:
        self._invoke()
302321

303-
@abc.abstractmethod
304322
def _invoke(self):
305-
...
323+
self._remotable_thunk()()
324+
325+
def _remotable_thunk(self) -> Callable[[], None]:
326+
"""Creates a remotable no-arg thunk that will execute this out of process.
327+
328+
This must return a no arg/result callable that can be pickled. While there
329+
are various ways to ensure this, here are a few guidelines:
330+
331+
* Must be a type/function defined at a module level.
332+
* Cannot be decorated.
333+
* Must only contain attributes with the same constraints.
334+
"""
335+
raise NotImplementedError(
336+
f"Action '{self}' does not implement remotable invocation"
337+
)
306338

307339

308340
class BuildContext(BuildDependency):
@@ -513,19 +545,20 @@ def _schedule_action(self, dep: BuildDependency):
513545
if isinstance(dep, BuildAction):
514546

515547
def invoke():
516-
dep.invoke()
548+
dep.invoke(self)
517549
return dep
518550

519551
print(f"Scheduling action: {dep}", file=self.stderr)
520-
if dep.concurrnecy == ActionConcurrency.NONE:
552+
if dep.concurrency == ActionConcurrency.NONE:
521553
invoke()
522-
elif dep.concurrnecy == ActionConcurrency.THREAD:
554+
elif (
555+
dep.concurrency == ActionConcurrency.THREAD
556+
or dep.concurrency == ActionConcurrency.PROCESS
557+
):
523558
dep.start(self.thread_pool_executor.submit(invoke))
524-
elif dep.concurrnecy == ActionConcurrency.PROCESS:
525-
dep.start(self.process_pool_executor.submit(invoke))
526559
else:
527560
raise AssertionError(
528-
f"Unhandled ActionConcurrency value: {dep.concurrnecy}"
561+
f"Unhandled ActionConcurrency value: {dep.concurrency}"
529562
)
530563
else:
531564
# Not schedulable. Just mark it as done.
compiler/bindings/python/iree/build/test_actions.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Copyright 2024 The IREE Authors
2+
#
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
from typing import Callable
8+
from iree.build.executor import ActionConcurrency, BuildAction
9+
10+
11+
class _ThunkTrampoline:
    """No-arg callable that applies a stored thunk to stored arguments.

    Defined at module level and holding only plain attributes so that an
    instance can be pickled, provided `thunk` and `args` are themselves
    pickleable.
    """

    def __init__(self, thunk, args):
        """Store the callable and the positional arguments to pass to it.

        Args:
            thunk: Callable to invoke.
            args: Iterable of positional arguments, unpacked at call time.
        """
        self.thunk = thunk
        self.args = args

    def __call__(self):
        """Call `thunk(*args)`; the thunk's return value is discarded."""
        self.thunk(*self.args)
18+
19+
20+
class ExecuteOutOfProcessThunkAction(BuildAction):
    """Executes a callback thunk with arguments.

    Both the thunk and args must be pickleable.
    """

    def __init__(self, thunk, args, concurrency=ActionConcurrency.PROCESS, **kwargs):
        """Wrap `thunk` and `args` in a trampoline for later invocation.

        Args:
            thunk: Callable to run.
            args: Positional arguments passed to `thunk` when it runs.
            concurrency: Requested concurrency mode; defaults to
                `ActionConcurrency.PROCESS`.
            **kwargs: Forwarded unchanged to `BuildAction.__init__`.
        """
        super().__init__(concurrency=concurrency, **kwargs)
        self.trampoline = _ThunkTrampoline(thunk, args)

    def _remotable_thunk(self) -> Callable[[], None]:
        """Return the trampoline built in `__init__` as the remotable thunk."""
        return self.trampoline

compiler/bindings/python/test/build_api/CMakeLists.txt

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,10 @@ if(IREE_INPUT_TORCH)
1313
"mnist_builder_test.py"
1414
)
1515
endif()
16+
17+
# Registers the out-of-process concurrency test for the build API.
iree_py_test(
  NAME
    concurrency_test
  SRCS
    "concurrency_test.py"
)
compiler/bindings/python/test/build_api/concurrency_test.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Copyright 2024 The IREE Authors
2+
#
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
import os
8+
from pathlib import Path
9+
import tempfile
10+
import unittest
11+
12+
from iree.build import *
13+
from iree.build.executor import BuildContext
14+
from iree.build.test_actions import ExecuteOutOfProcessThunkAction
15+
16+
17+
@entrypoint
def write_out_of_process_pid():
    """Build entrypoint that produces `pid.txt` via an out-of-process action.

    Allocates the output file in the current build context and makes it
    depend on an `ExecuteOutOfProcessThunkAction` that calls
    `_write_pid_file` with the file's filesystem path.

    Returns:
        The allocated output file object.
    """
    context = BuildContext.current()
    output_file = context.allocate_file("pid.txt")
    action = ExecuteOutOfProcessThunkAction(
        _write_pid_file,
        args=[output_file.get_fs_path()],
        desc="Writing pid file",
        executor=context.executor,
    )
    # The file is only considered produced once the action has run.
    output_file.deps.add(action)
    return output_file
29+
30+
31+
def _write_pid_file(output_path: Path):
    """Write the pid of the executing process to `output_path`.

    Kept as a plain module-level function so it can be pickled by reference
    and handed to another process.

    Args:
        output_path: Destination file. A `Path` is expected; string paths are
            accepted as well.
    """
    pid = os.getpid()
    print(f"Running action out of process: pid={pid}")
    # Coerce so the function does not depend on the caller passing a Path.
    Path(output_path).write_text(str(pid))
35+
36+
37+
class ConcurrencyTest(unittest.TestCase):
    """Checks that a PROCESS-concurrency action runs outside the test process."""

    def setUp(self):
        """Create a scratch output directory that is removed after the test."""
        self._temp_dir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True)
        # Register removal right away instead of calling __enter__/__exit__
        # by hand; cleanups run even when the test body fails.
        self.addCleanup(self._temp_dir.cleanup)
        self.output_path = Path(self._temp_dir.name)

    def testProcessConcurrency(self):
        """Run the entrypoint and compare the recorded pid with our own."""
        parent_pid = os.getpid()
        print(f"Testing out of process concurrency: pid={parent_pid}")
        iree_build_main(
            args=["write_out_of_process_pid", "--output-dir", str(self.output_path)]
        )
        pid_file = (
            self.output_path / "genfiles" / "write_out_of_process_pid" / "pid.txt"
        )
        child_pid = int(pid_file.read_text())
        print(f"Got child pid={child_pid}")
        # A differing pid shows the thunk ran in another process.
        self.assertNotEqual(parent_pid, child_pid)
58+
59+
60+
# Allow running this test file directly as a script.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)