|
1 | 1 | from __future__ import annotations |
| 2 | + |
2 | 3 | from abc import ABC, abstractmethod |
| 4 | +import asyncio |
3 | 5 | from enum import Enum |
| 6 | +from pathlib import Path |
4 | 7 | import re |
5 | 8 | import subprocess |
6 | 9 | from typing import Optional, Iterable, Sequence |
7 | 10 |
|
| 11 | +import aiofiles |
| 12 | + |
8 | 13 | from xtl import settings |
9 | 14 | from xtl.config.settings import DependencySettings |
10 | 15 | from xtl.common.compatibility import PY310_OR_LESS, XTL_COMPUTE_SITE |
11 | | -from xtl.jobs.batchfiles import BatchFile |
| 16 | +from xtl.jobs.batchfiles import BatchFile, BatchFileStatus |
12 | 17 | from xtl.jobs.policies import CommandPolicy, CommandPolicyType |
13 | 18 | from xtl.jobs.shells import Shell, ShellType |
14 | 19 | from xtl.logging import Logger |
@@ -201,9 +206,6 @@ def prepare_content(self, content: str) -> str: |
201 | 206 | def prepare_postamble(self) -> str: |
202 | 207 | return self.policy.intercept_postamble('') if self.policy else '' |
203 | 208 |
|
204 | | - async def execute_batch(self, batch: BatchFile, **kwargs): |
205 | | - raise NotImplementedError() |
206 | | - |
207 | 209 | def check_dependencies(self, dependencies: Iterable[DependencySettings | str] |
208 | 210 | | DependencySettings | str |
209 | 211 | | None, |
@@ -236,6 +238,122 @@ def check_dependencies(self, dependencies: Iterable[DependencySettings | str] |
236 | 238 | logger.debug('All dependencies are met.') |
237 | 239 | return True |
238 | 240 |
|
| 241 | + async def execute_batch(self, batch: BatchFile, stdout: Path = None, |
| 242 | + stderr: Path = None, **kwargs): |
| 243 | + # Get command to execute the batch file |
| 244 | + cmd = batch.shell.get_execute_batch_command(batch.file) |
| 245 | + |
| 246 | + if PY310_OR_LESS: |
| 247 | + # TODO: Update exception handling to asyncio.TaskGroup when we drop support |
| 248 | + # for Python 3.10 |
| 249 | + pass |
| 250 | + |
| 251 | + # Initialize stream logging tasks |
| 252 | + stdout_task, stderr_task = None, None |
| 253 | + try: |
| 254 | + # Start batch execution |
| 255 | + batch._process = await asyncio.create_subprocess_exec( |
| 256 | + *cmd, |
| 257 | + shell=False, |
| 258 | + stdout=asyncio.subprocess.PIPE, |
| 259 | + stderr=asyncio.subprocess.PIPE |
| 260 | + ) |
| 261 | + batch._status = BatchFileStatus.RUNNING |
| 262 | + |
| 263 | + # Start logging of streams |
| 264 | + stdout_task = asyncio.create_task( |
| 265 | + self._log_stream_to_file(batch.process.stdout, stdout), |
| 266 | + name='stdout_logger' |
| 267 | + ) |
| 268 | + stderr_task = asyncio.create_task( |
| 269 | + self._log_stream_to_file(batch.process.stderr, stderr), |
| 270 | + name='stderr_logger' |
| 271 | + ) |
| 272 | + |
| 273 | + # Wait for all tasks to complete. This is where the main thread waits. |
| 274 | + await asyncio.gather(batch.process.wait(), stdout_task, stderr_task) |
| 275 | + |
| 276 | + # Check if the batch was cancelled by .cancel_batch() |
| 277 | + if batch.status != BatchFileStatus.CANCELLED: |
| 278 | + batch._status = BatchFileStatus.COMPLETED |
| 279 | + except asyncio.CancelledError as e: |
| 280 | + # Raise cancellation up the stack after marking the batch as cancelled |
| 281 | + logger.error('Batch execution was cancelled by user') |
| 282 | + batch._status = BatchFileStatus.CANCELLED |
| 283 | + raise e |
| 284 | + except Exception as e: |
| 285 | + # Log any other exceptions during execution |
| 286 | + logger.error('Error executing batch file %{file}s: %{exc}s', |
| 287 | + {'file': batch.file, 'exc': str(e)}) |
| 288 | + batch._status = BatchFileStatus.FAILED |
| 289 | + finally: |
| 290 | + # Terminate batch if still running |
| 291 | + if batch.process and batch.process.returncode is None: |
| 292 | + logger.debug('Terminating batch file %{file}s', |
| 293 | + {'file': batch.file}) |
| 294 | + await self._terminate_batch(batch) |
| 295 | + |
| 296 | + # Terminate logging tasks |
| 297 | + for task in (stdout_task, stderr_task): |
| 298 | + if task and not task.done(): |
| 299 | + logger.debug('Cancelling async task %{name}s', |
| 300 | + {'name': task.get_name()}) |
| 301 | + task.cancel() |
| 302 | + try: |
| 303 | + await task |
| 304 | + except asyncio.CancelledError: |
| 305 | + pass |
| 306 | + |
| 307 | + async def cancel_batch(self, batch: BatchFile, **kwargs): |
| 308 | + if batch.status != BatchFileStatus.RUNNING: |
| 309 | + logger.warning('Cannot cancel batch file %{file}s; it is not running', |
| 310 | + {'file': batch.file}) |
| 311 | + return |
| 312 | + logger.info('Cancelling batch file %{file}s', {'file': batch.file}) |
| 313 | + batch._status = BatchFileStatus.CANCELLED |
| 314 | + await self._terminate_batch(batch) |
| 315 | + |
| 316 | + @staticmethod |
| 317 | + async def _terminate_batch(batch: BatchFile, **kwargs): |
| 318 | + if batch.process and batch.process.returncode is None: |
| 319 | + # Request graceful termination of the process |
| 320 | + logger.debug('Sending SIGTERM to batch process PID %{pid}d', |
| 321 | + {'pid': batch.process.pid}) |
| 322 | + batch.process.terminate() |
| 323 | + try: |
| 324 | + # Wait 10 s for a graceful shutdown |
| 325 | + await asyncio.wait_for(batch.process.wait(), timeout=10) |
| 326 | + except asyncio.TimeoutError: |
| 327 | + # Force kill the process |
| 328 | + logger.debug('Sending SIGKILL to batch process PID %{pid}d', |
| 329 | + {'pid': batch.process.pid}) |
| 330 | + batch.process.kill() |
| 331 | + await batch.process.wait() |
| 332 | + |
| 333 | + @staticmethod |
| 334 | + async def _log_stream_to_file(stream: asyncio.StreamReader, |
| 335 | + file: Optional[Path] = None): |
| 336 | + # Open the file for writing if provided |
| 337 | + f = await aiofiles.open(file, mode='wb') if file else None |
| 338 | + try: |
| 339 | + while True: |
| 340 | + # Read 8 KB of data from the stream |
| 341 | + chunk = await stream.read(8192) |
| 342 | + |
| 343 | + # Break if no more data is available |
| 344 | + if not chunk: |
| 345 | + break |
| 346 | + |
| 347 | + if f: |
| 348 | + # Write the chunk to the file |
| 349 | + await f.write(chunk) |
| 350 | + # Otherwise, the chunk is discarded, but still drained from the stream |
| 351 | + finally: |
| 352 | + # Close the file if it was opened |
| 353 | + if f: |
| 354 | + await f.close() |
| 355 | + |
| 356 | + |
239 | 357 |
|
240 | 358 | class ModulesSite(LocalSite): |
241 | 359 | """ |
|
0 commit comments