Skip to content

Commit 242bca9

Browse files
committed
🚧 First go at LocalSite.execute_batch()
xtl.jobs.sites:LocalSite - First implementation of the .execute_batch() method, which handles termination and cancellation of processes. pyproject.toml - Added aiofiles to the list of dependencies.
1 parent af92b55 commit 242bca9

File tree

4 files changed

+144
-6
lines changed

4 files changed

+144
-6
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ classifiers = [
3030
]
3131
requires-python = ">=3.10"
3232
dependencies = [
33+
"aiofiles>=25.1.0",
3334
"click>=8.1.8",
3435
"defusedxml>=0.7.1",
3536
"distro>=1.9.0",

src/xtl/common/compatibility.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@
1616
PY310_OR_LESS: bool = (PY_VERS < (3, 11))
1717
"""Python version is 3.10 or less"""
1818
# Features missing in Python 3.10:
19-
# - ``enum.StrEnum`` (available in Python 3.11+)
19+
# - ``enum.StrEnum``
20+
# - ``asyncio.TaskGroup``
2021

2122
# TODO: Remove when dropping support for Python 3.11 (EOL: 10/2027)
2223
PY311_OR_LESS: bool = (PY_VERS < (3, 12))

src/xtl/jobs/batchfiles.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from xtl.common.os import FilePermissions
99
from xtl.config.settings import DependencySettings
1010
from xtl.jobs.shells import ShellType, DefaultShell
11-
from xtl.jobs.sites import ComputeSiteType, LocalSite
11+
from xtl.jobs.sites import ComputeSiteType, LocalSite, SchedulerSite
1212

1313
if PY310_OR_LESS:
1414
class StrEnum(str, Enum): ...
@@ -150,3 +150,21 @@ def save(self, overwrite: bool = False, update_permissions: bool = True):
150150

151151
self._saved = True
152152

153+
# Aliases for execution and cancellation
154+
# These methods provide an alternative API for interacting with batch files, rather
155+
# than going through the compute site directly.
156+
async def execute(self, schedule: bool = False):
    """
    Execute this batch file through its compute site.

    :param schedule: if True *and* the compute site is a ``SchedulerSite``, the
        batch is handed to ``schedule_batch()``; in every other case it is handed
        to ``execute_batch()``.
    :raises RuntimeError: if the batch file has not been saved yet
    """
    # A batch that was never written to disk cannot be run
    if not self._saved:
        raise RuntimeError('Batch file has not been saved yet. '
                           'Call `save()` before `execute()`.')

    site = self.compute_site
    # The isinstance check is only evaluated when scheduling was requested
    use_scheduler = schedule and isinstance(site, SchedulerSite)
    run = site.schedule_batch if use_scheduler else site.execute_batch
    await run(self)
164+
165+
async def cancel(self):
    """
    Cancel this batch file by delegating to ``cancel_batch()`` of its compute site.
    """
    site = self.compute_site
    await site.cancel_batch(self)
167+
168+
# @classmethod
169+
# def from_config(cls, config: 'BatchConfig' | dict) -> BatchFile: ...
170+

src/xtl/jobs/sites.py

Lines changed: 122 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,19 @@
11
from __future__ import annotations
2+
23
from abc import ABC, abstractmethod
4+
import asyncio
35
from enum import Enum
6+
from pathlib import Path
47
import re
58
import subprocess
69
from typing import Optional, Iterable, Sequence
710

11+
import aiofiles
12+
813
from xtl import settings
914
from xtl.config.settings import DependencySettings
1015
from xtl.common.compatibility import PY310_OR_LESS, XTL_COMPUTE_SITE
11-
from xtl.jobs.batchfiles import BatchFile
16+
from xtl.jobs.batchfiles import BatchFile, BatchFileStatus
1217
from xtl.jobs.policies import CommandPolicy, CommandPolicyType
1318
from xtl.jobs.shells import Shell, ShellType
1419
from xtl.logging import Logger
@@ -201,9 +206,6 @@ def prepare_content(self, content: str) -> str:
201206
def prepare_postamble(self) -> str:
    """
    Return the postamble for a batch file.

    When a policy is set, an empty postamble is passed through the policy's
    ``intercept_postamble()`` and its result is returned; without a policy the
    postamble is an empty string.
    """
    if not self.policy:
        return ''
    return self.policy.intercept_postamble('')
203208

204-
async def execute_batch(self, batch: BatchFile, **kwargs):
205-
raise NotImplementedError()
206-
207209
def check_dependencies(self, dependencies: Iterable[DependencySettings | str]
208210
| DependencySettings | str
209211
| None,
@@ -236,6 +238,122 @@ def check_dependencies(self, dependencies: Iterable[DependencySettings | str]
236238
logger.debug('All dependencies are met.')
237239
return True
238240

241+
async def execute_batch(self, batch: BatchFile, stdout: Optional[Path] = None,
                        stderr: Optional[Path] = None, **kwargs):
    """
    Run a batch file as a subprocess and wait for it to finish.

    The command is obtained from the batch's shell and started without an
    intermediate shell. The process' stdout and stderr are piped and drained
    concurrently by two background tasks, which write each stream to the given
    file (or discard it when no file is given).

    The batch status is updated along the way: ``RUNNING`` once the process has
    started, ``COMPLETED`` when the process and both logging tasks finish (unless
    the batch was marked ``CANCELLED`` in the meantime), ``CANCELLED`` when the
    awaiting task itself is cancelled, and ``FAILED`` on any other exception.
    Exceptions other than cancellation are logged and not re-raised.

    :param batch: the batch file to execute
    :param stdout: file that receives the process' standard output, or None to
        discard it
    :param stderr: file that receives the process' standard error, or None to
        discard it
    :param kwargs: accepted for interface compatibility; not used here
    :raises asyncio.CancelledError: re-raised after the batch has been marked as
        cancelled
    """
    # Get command to execute the batch file
    cmd = batch.shell.get_execute_batch_command(batch.file)

    # TODO: Update exception handling to asyncio.TaskGroup when we drop support
    #  for Python 3.10

    # Initialize stream logging tasks, so that the cleanup in `finally` can tell
    # whether they were ever started
    stdout_task, stderr_task = None, None
    try:
        # Start batch execution
        batch._process = await asyncio.create_subprocess_exec(
            *cmd,
            shell=False,
            stdout=asyncio.subprocess.PIPE,
            stderr=asyncio.subprocess.PIPE
        )
        batch._status = BatchFileStatus.RUNNING

        # Start logging of streams
        stdout_task = asyncio.create_task(
            self._log_stream_to_file(batch.process.stdout, stdout),
            name='stdout_logger'
        )
        stderr_task = asyncio.create_task(
            self._log_stream_to_file(batch.process.stderr, stderr),
            name='stderr_logger'
        )

        # Wait for all tasks to complete. This is where the main thread waits.
        await asyncio.gather(batch.process.wait(), stdout_task, stderr_task)

        # Check if the batch was cancelled by .cancel_batch()
        if batch.status != BatchFileStatus.CANCELLED:
            batch._status = BatchFileStatus.COMPLETED
    except asyncio.CancelledError:
        # Raise cancellation up the stack after marking the batch as cancelled
        logger.error('Batch execution was cancelled by user')
        batch._status = BatchFileStatus.CANCELLED
        raise
    except Exception as e:
        # Log any other exceptions during execution
        logger.error('Error executing batch file %(file)s: %(exc)s',
                     {'file': batch.file, 'exc': str(e)})
        batch._status = BatchFileStatus.FAILED
    finally:
        # Terminate batch if still running
        if batch.process and batch.process.returncode is None:
            logger.debug('Terminating batch file %(file)s',
                         {'file': batch.file})
            await self._terminate_batch(batch)

        # Terminate logging tasks
        for task in (stdout_task, stderr_task):
            if task and not task.done():
                logger.debug('Cancelling async task %(name)s',
                             {'name': task.get_name()})
                task.cancel()
                try:
                    # Give the task a chance to run its own cleanup
                    await task
                except asyncio.CancelledError:
                    pass
306+
307+
async def cancel_batch(self, batch: BatchFile, **kwargs):
308+
if batch.status != BatchFileStatus.RUNNING:
309+
logger.warning('Cannot cancel batch file %{file}s; it is not running',
310+
{'file': batch.file})
311+
return
312+
logger.info('Cancelling batch file %{file}s', {'file': batch.file})
313+
batch._status = BatchFileStatus.CANCELLED
314+
await self._terminate_batch(batch)
315+
316+
@staticmethod
async def _terminate_batch(batch: BatchFile, **kwargs):
    """
    Stop the process of a batch file, escalating from terminate to kill.

    Does nothing if the batch has no process or the process has already exited.
    Otherwise ``terminate()`` is called and the process is given 10 seconds to
    exit; if it is still running after that, ``kill()`` is called and the
    process is awaited.

    :param batch: the batch file whose process should be stopped
    :param kwargs: accepted for interface compatibility; not used here
    """
    process = batch.process
    if not process or process.returncode is not None:
        # No process was started, or it has already exited
        return

    # Request graceful termination of the process
    logger.debug('Sending SIGTERM to batch process PID %(pid)d',
                 {'pid': process.pid})
    process.terminate()
    try:
        # Wait 10 s for a graceful shutdown
        await asyncio.wait_for(process.wait(), timeout=10)
    except asyncio.TimeoutError:
        # Force kill the process
        logger.debug('Sending SIGKILL to batch process PID %(pid)d',
                     {'pid': process.pid})
        process.kill()
        await process.wait()
332+
333+
@staticmethod
334+
async def _log_stream_to_file(stream: asyncio.StreamReader,
335+
file: Optional[Path] = None):
336+
# Open the file for writing if provided
337+
f = await aiofiles.open(file, mode='wb') if file else None
338+
try:
339+
while True:
340+
# Read 8 KB of data from the stream
341+
chunk = await stream.read(8192)
342+
343+
# Break if no more data is available
344+
if not chunk:
345+
break
346+
347+
if f:
348+
# Write the chunk to the file
349+
await f.write(chunk)
350+
# Otherwise, the chunk is discarded, but still drained from the stream
351+
finally:
352+
# Close the file if it was opened
353+
if f:
354+
await f.close()
355+
356+
239357

240358
class ModulesSite(LocalSite):
241359
"""

0 commit comments

Comments
 (0)