Skip to content

Commit c3e91d7

Browse files
committed
✨ Added sanitization in BatchFile template rendering
xtl.jobs.batchfiles:BatchTemplate - Added `__MAGIC_VARS__` dictionary to store 'magic' context variables (name is XTL*) that get dynamically computed when substituted xtl.jobs.batchfiles:BatchFile - The context passed to `BatchTemplate.substitute` is now first sanitized (e.g. escaping of paths) - Magic variables are now expanded and passed along with the context
1 parent 497cefe commit c3e91d7

File tree

1 file changed

+56
-4
lines changed

1 file changed

+56
-4
lines changed

src/xtl/jobs/batchfiles.py

Lines changed: 56 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,17 @@
11
from __future__ import annotations
22
import asyncio
3+
from datetime import datetime
4+
import shlex
35
from enum import Enum
46
from pathlib import Path
57
from string import Template
6-
from typing import Iterable, TYPE_CHECKING
8+
from typing import Callable, Iterable, TYPE_CHECKING
79

810
if TYPE_CHECKING:
911
from xtl.config.settings import DependencySettings
1012
from xtl.jobs.sites import ComputeSiteType, LocalSite, SchedulerSite
1113
from xtl.jobs.config2 import BatchJobConfig
14+
from xtl import version
1215
from xtl.common.compatibility import PY310_OR_LESS
1316
from xtl.common.os import FilePermissions
1417
from xtl.jobs.shells import Shell, ShellType, DefaultShell
@@ -62,6 +65,12 @@ class BatchTemplate(Template):
6265
)
6366
'''
6467

68+
__MAGIC_VARS__: dict[str, Callable] = {
69+
'XTL_COMMENT': lambda **kwargs: kwargs.get('comment_char', ''),
70+
'XTL_NL': lambda **kwargs: kwargs.get('new_line_char', ''),
71+
'XTL_DOCSTRING': lambda **kwargs: f'Generated by xtl {version} on {datetime.now().isoformat()}',
72+
}
73+
6574

6675
class BatchFile:
6776

@@ -285,8 +294,45 @@ async def cancel(self):
285294
"""
286295
await self.compute_site.cancel_batch(self)
287296

288-
@staticmethod
289-
def _render_template(template: str, context: dict) -> str:
297+
def _magic_context(self) -> dict:
298+
"""
299+
Returns a dictionary of "magic" context variables that can be used in batch file
300+
templates. These variables are automatically generated and can be accessed in
301+
templates using their keys (e.g., `__XTL_DOCSTRING__`).
302+
"""
303+
kwargs = {
304+
'new_line_char': self.shell.new_line_char,
305+
'comment_char': self.shell.comment_char,
306+
}
307+
return {var: func(**kwargs) for var, func in BatchTemplate.__MAGIC_VARS__.items()}
308+
309+
def _sanitize_context(self, context: dict) -> dict:
310+
"""
311+
Sanitize the context dictionary for safe insertion into batch file templates.
312+
"""
313+
sanitized = {}
314+
for key, value in context.items():
315+
match value:
316+
case Path():
317+
path_str = str(value)
318+
if self.shell.is_posix:
319+
# Use shlex.quote to properly escape the path for POSIX shells
320+
sanitized[key] = shlex.quote(path_str)
321+
elif self.shell == Shell.CMD: # CMD requires double quotes for escaping
322+
# Escape internal quotes by doubling them
323+
path_str = path_str.replace('"', '""')
324+
sanitized[key] = f'"{path_str}"'
325+
elif self.shell == Shell.POWERSHELL: # PWSH requires single quotes for escaping
326+
# Escape internal quotes by doubling them
327+
path_str = path_str.replace('\'', '\'\'')
328+
sanitized[key] = f'\'{path_str}\''
329+
else:
330+
raise ValueError(f'Unsupported shell for path sanitization: {self.shell}')
331+
case _:
332+
sanitized[key] = value
333+
return sanitized
334+
335+
def _render_template(self, template: str, context: dict = None) -> str:
290336
"""
291337
Render a batch file template with the given context.
292338
@@ -298,7 +344,13 @@ def _render_template(template: str, context: dict) -> str:
298344
:raises ValueError: If the `template` contains invalid placeholders.
299345
"""
300346
batch_template = BatchTemplate(template)
301-
return batch_template.substitute(context or {})
347+
try:
348+
magic = self._magic_context()
349+
sanitized = self._sanitize_context(context or {})
350+
return batch_template.substitute(sanitized | magic)
351+
except KeyError as e:
352+
missing_key = e.args[0]
353+
raise KeyError(f'Missing required key for {BatchFile.__name__} template: {missing_key}')
302354

303355
@classmethod
304356
def from_config(cls, config: BatchJobConfig | dict, context: dict = None) \

0 commit comments

Comments
 (0)