|
4 | 4 |
|
5 | 5 | import copy |
6 | 6 | import logging |
| 7 | +import multiprocessing |
7 | 8 | import os |
| 9 | +import platform |
| 10 | +from concurrent.futures import ( |
| 11 | + FIRST_COMPLETED, |
| 12 | + ProcessPoolExecutor, |
| 13 | + wait, |
| 14 | +) |
8 | 15 | from dataclasses import dataclass |
9 | 16 | from typing import Callable, Dict, Optional, Union |
10 | 17 |
|
@@ -46,16 +53,20 @@ def _get_loader(self) -> Callable: |
46 | 53 | assert callable(loader) |
47 | 54 | return loader |
48 | 55 |
|
49 | | - def load_tasks(self, parameters, loaded_tasks, write_artifacts): |
| 56 | + def load_tasks(self, parameters, kind_dependencies_tasks, write_artifacts): |
| 57 | + logger.debug(f"Loading tasks for kind {self.name}") |
| 58 | + |
| 59 | + parameters = Parameters(**parameters) |
50 | 60 | loader = self._get_loader() |
51 | 61 | config = copy.deepcopy(self.config) |
52 | 62 |
|
53 | | - kind_dependencies = config.get("kind-dependencies", []) |
54 | | - kind_dependencies_tasks = { |
55 | | - task.label: task for task in loaded_tasks if task.kind in kind_dependencies |
56 | | - } |
57 | | - |
58 | | - inputs = loader(self.name, self.path, config, parameters, loaded_tasks) |
| 63 | + inputs = loader( |
| 64 | + self.name, |
| 65 | + self.path, |
| 66 | + config, |
| 67 | + parameters, |
| 68 | + list(kind_dependencies_tasks.values()), |
| 69 | + ) |
59 | 70 |
|
60 | 71 | transforms = TransformSequence() |
61 | 72 | for xform_path in config["transforms"]: |
@@ -89,6 +100,7 @@ def load_tasks(self, parameters, loaded_tasks, write_artifacts): |
89 | 100 | ) |
90 | 101 | for task_dict in transforms(trans_config, inputs) |
91 | 102 | ] |
| 103 | + logger.info(f"Generated {len(tasks)} tasks for kind {self.name}") |
92 | 104 | return tasks |
93 | 105 |
|
94 | 106 | @classmethod |
@@ -253,6 +265,103 @@ def _load_kinds(self, graph_config, target_kinds=None): |
253 | 265 | except KindNotFound: |
254 | 266 | continue |
255 | 267 |
|
| 268 | + def _load_tasks_serial(self, kinds, kind_graph, parameters): |
| 269 | + all_tasks = {} |
| 270 | + for kind_name in kind_graph.visit_postorder(): |
| 271 | + logger.debug(f"Loading tasks for kind {kind_name}") |
| 272 | + |
| 273 | + kind = kinds.get(kind_name) |
| 274 | + if not kind: |
| 275 | + message = f'Could not find the kind "{kind_name}"\nAvailable kinds:\n' |
| 276 | + for k in sorted(kinds): |
| 277 | + message += f' - "{k}"\n' |
| 278 | + raise Exception(message) |
| 279 | + |
| 280 | + try: |
| 281 | + new_tasks = kind.load_tasks( |
| 282 | + parameters, |
| 283 | + { |
| 284 | + k: t |
| 285 | + for k, t in all_tasks.items() |
| 286 | + if t.kind in kind.config.get("kind-dependencies", []) |
| 287 | + }, |
| 288 | + self._write_artifacts, |
| 289 | + ) |
| 290 | + except Exception: |
| 291 | + logger.exception(f"Error loading tasks for kind {kind_name}:") |
| 292 | + raise |
| 293 | + for task in new_tasks: |
| 294 | + if task.label in all_tasks: |
| 295 | + raise Exception("duplicate tasks with label " + task.label) |
| 296 | + all_tasks[task.label] = task |
| 297 | + |
| 298 | + return all_tasks |
| 299 | + |
def _load_tasks_parallel(self, kinds, kind_graph, parameters):
    """Generate tasks for all kinds, running independent kinds in parallel.

    Kinds whose kind-dependencies are all satisfied are submitted to a
    forked worker pool; as each finishes, its dependency edges are pruned
    and any newly unblocked kinds are submitted. Requires a "fork"
    multiprocessing context (see the Linux-only guard in ``_run``).

    NOTE(review): this mutates the caller's ``kinds`` dict in place
    (``del kinds[kind]`` below), unlike the serial path — callers must
    not reuse ``kinds`` afterwards.

    NOTE(review): if ``kind_graph.edges`` contained a cycle, the ready
    set would become empty with kinds still pending and this would
    return a partial result silently; presumably the graph is already
    validated to be acyclic upstream — confirm.

    Returns a dict mapping task label -> task for all generated tasks.

    Raises:
        Exception: if a ready kind is missing from ``kinds``, if two
            tasks share a label, or whatever exception a worker raised
            while loading a kind's tasks.
    """
    all_tasks = {}
    # Bookkeeping for in-flight work: future -> kind name, plus the set
    # of unresolved dependency edges. Edges appear to be ordered
    # (dependent, dependency, ...) tuples — edge[0] waits on edge[1].
    futures_to_kind = {}
    futures = set()
    edges = set(kind_graph.edges)

    with ProcessPoolExecutor(
        mp_context=multiprocessing.get_context("fork")
    ) as executor:

        def submit_ready_kinds():
            """Create the next batch of tasks for kinds without dependencies."""
            nonlocal kinds, edges, futures
            # Snapshot so the dependency filter below is stable even as
            # all_tasks grows between submissions.
            loaded_tasks = all_tasks.copy()
            # A kind is ready when it has no remaining dependency edges
            # and is not already running.
            kinds_with_deps = {edge[0] for edge in edges}
            ready_kinds = (
                set(kinds) - kinds_with_deps - set(futures_to_kind.values())
            )
            for name in ready_kinds:
                kind = kinds.get(name)
                if not kind:
                    message = (
                        f'Could not find the kind "{name}"\nAvailable kinds:\n'
                    )
                    for k in sorted(kinds):
                        message += f' - "{k}"\n'
                    raise Exception(message)

                # Pass parameters as a plain dict across the process
                # boundary; load_tasks rebuilds a Parameters object on
                # the worker side. Only tasks from this kind's declared
                # kind-dependencies are forwarded.
                future = executor.submit(
                    kind.load_tasks,
                    dict(parameters),
                    {
                        k: t
                        for k, t in loaded_tasks.items()
                        if t.kind in kind.config.get("kind-dependencies", [])
                    },
                    self._write_artifacts,
                )
                futures.add(future)
                futures_to_kind[future] = name

        submit_ready_kinds()
        while futures:
            # Wake as soon as any kind finishes so its dependents can
            # be scheduled without waiting for the whole batch.
            done, _ = wait(futures, return_when=FIRST_COMPLETED)
            for future in done:
                # Fail fast: abandon outstanding work and re-raise the
                # worker's exception in the parent.
                if exc := future.exception():
                    executor.shutdown(wait=False, cancel_futures=True)
                    raise exc
                # Here ``kind`` is the kind *name* (string), not a Kind
                # object.
                kind = futures_to_kind.pop(future)
                futures.remove(future)

                for task in future.result():
                    if task.label in all_tasks:
                        raise Exception("duplicate tasks with label " + task.label)
                    all_tasks[task.label] = task

                # Update state for next batch of futures.
                del kinds[kind]
                edges = {e for e in edges if e[1] != kind}

            # Submit any newly unblocked kinds
            submit_ready_kinds()

    return all_tasks
| 364 | + |
256 | 365 | def _run(self): |
257 | 366 | logger.info("Loading graph configuration.") |
258 | 367 | graph_config = load_graph_config(self.root_dir) |
@@ -307,31 +416,18 @@ def _run(self): |
307 | 416 | ) |
308 | 417 |
|
309 | 418 | logger.info("Generating full task set") |
310 | | - all_tasks = {} |
311 | | - for kind_name in kind_graph.visit_postorder(): |
312 | | - logger.debug(f"Loading tasks for kind {kind_name}") |
313 | | - |
314 | | - kind = kinds.get(kind_name) |
315 | | - if not kind: |
316 | | - message = f'Could not find the kind "{kind_name}"\nAvailable kinds:\n' |
317 | | - for k in sorted(kinds): |
318 | | - message += f' - "{k}"\n' |
319 | | - raise Exception(message) |
| 419 | + # Current parallel generation relies on multiprocessing, and forking. |
| 420 | + # This causes problems on Windows and macOS due to how new processes |
| 421 | + # are created there, and how doing so reinitializes global variables |
| 422 | + # that are modified earlier in graph generation, that doesn't get |
| 423 | + # redone in the new processes. Ideally this would be fixed, or we |
| 424 | + # would take another approach to parallel kind generation. In the |
| 425 | + # meantime, it's not supported outside of Linux. |
| 426 | + if platform.system() != "Linux" or os.environ.get("TASKGRAPH_SERIAL"): |
| 427 | + all_tasks = self._load_tasks_serial(kinds, kind_graph, parameters) |
| 428 | + else: |
| 429 | + all_tasks = self._load_tasks_parallel(kinds, kind_graph, parameters) |
320 | 430 |
|
321 | | - try: |
322 | | - new_tasks = kind.load_tasks( |
323 | | - parameters, |
324 | | - list(all_tasks.values()), |
325 | | - self._write_artifacts, |
326 | | - ) |
327 | | - except Exception: |
328 | | - logger.exception(f"Error loading tasks for kind {kind_name}:") |
329 | | - raise |
330 | | - for task in new_tasks: |
331 | | - if task.label in all_tasks: |
332 | | - raise Exception("duplicate tasks with label " + task.label) |
333 | | - all_tasks[task.label] = task |
334 | | - logger.info(f"Generated {len(new_tasks)} tasks for kind {kind_name}") |
335 | 431 | full_task_set = TaskGraph(all_tasks, Graph(frozenset(all_tasks), frozenset())) |
336 | 432 | yield self.verify("full_task_set", full_task_set, graph_config, parameters) |
337 | 433 |
|
|
0 commit comments