3 | 3 | import subprocess as sub |
4 | 4 | import sys |
5 | 5 | import os |
| 6 | +import re |
6 | 7 | import stat |
7 | 8 | from pathlib import Path |
| 9 | +from pprint import pprint |
8 | 10 | import shutil |
9 | 11 | import traceback |
10 | | -from typing import List, Optional, Dict |
11 | 12 | from datetime import datetime |
| 13 | +from typing import List, Optional, Dict |
| 14 | +import requests |
| 15 | +from sophios.wic_types import Json |
12 | 16 |
13 | 17 | try: |
14 | 18 | import cwltool.main |
30 | 34 | from .plugins import logging_filters |
31 | 35 |
32 | 36 |
| 37 | +def sanitize_env_vars(env_vars: Dict[str, str]) -> Dict[str, str]: |
| 38 | + """ |
| 39 | + Sanitizes a dictionary of user-defined environment variables, assuming all |
| 40 | + values are strings. |
| 41 | +
| 42 | + - Ensures keys are valid Bash variable names. |
| 43 | + - Removes potentially dangerous characters from string values. |
| 44 | +
| 45 | + Args: |
| 46 | + env_vars (Dict[str, str]): A dictionary of string key-value pairs. |
| 47 | +
| 48 | + Returns: |
| 49 | + Dict[str, str]: A new dictionary with sanitized key-value pairs. |
| 50 | + """ |
| 51 | + sanitized = {} |
| 52 | + |
| 53 | + # Regex for a valid Bash variable name |
| 54 | + valid_key_pattern = re.compile(r'^[a-zA-Z_][a-zA-Z0-9_]*$') |
| 55 | + |
| 56 | + # Characters to remove from values to prevent command injection |
| 57 | + dangerous_chars_pattern = re.compile(r'[;`\'"$()|<>&!\n\r]') |
| 58 | + |
| 59 | + for key, value in env_vars.items(): |
| 60 | + # Step 1: Validate the key. |
| 61 | + if not valid_key_pattern.fullmatch(key): |
| 62 | + print(f"Warning: Invalid environment variable key '{key}' skipped.") |
| 63 | + continue |
| 64 | + |
| 65 | + # Step 2: Sanitize the value. |
| 66 | + sanitized_value = dangerous_chars_pattern.sub('', value) |
| 67 | + sanitized[key] = sanitized_value |
| 68 | + |
| 69 | + return sanitized |
| 70 | + |
| 71 | + |
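For illustration, here is a minimal sketch of how sanitize_env_vars behaves on hostile input; the variable names and values are hypothetical:

```python
# Hypothetical user-supplied variables: one valid key whose value contains shell
# metacharacters, and one key that is not a valid Bash identifier.
user_vars = {
    "MY_TOOL_OPTS": "--threads=4; echo $(whoami)",
    "2BAD_KEY": "value",
}
print(sanitize_env_vars(user_vars))
# Prints a warning for '2BAD_KEY' and returns:
# {'MY_TOOL_OPTS': '--threads=4 echo whoami'}
```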
| 72 | +def create_safe_env(user_env: Dict[str, str]) -> Dict[str, str]: |
| 73 | +    """Generate a sanitized copy of the environment without modifying os.environ.""" |
| 74 | + sanitized_user_env = sanitize_env_vars(user_env) |
| 75 | + return {**os.environ, **sanitized_user_env} |
| 76 | + |
| 77 | + |
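A possible usage sketch for create_safe_env: the sanitized copy is passed explicitly to a child process instead of mutating os.environ (the command and variable are illustrative):

```python
import subprocess as sub

safe_env = create_safe_env({"CWLTOOL_OPTIONS": "--parallel"})  # illustrative variable
# Only the child process sees the extra variable; the parent environment is untouched.
sub.run(["cwltool", "--version"], env=safe_env, check=True)
```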
33 | 78 | def generate_run_script(cmdline: str) -> None: |
34 | 79 | """Writes the command used to invoke the cwl-runner to run.sh |
35 | 80 | Does not actually invoke ./run.sh |
@@ -189,6 +234,72 @@ def run_local(run_args_dict: Dict[str, str], use_subprocess: bool, |
189 | 234 | return retval |
190 | 235 |
191 | 236 |
| 237 | +def run_compute(workflow_name: str, workflow: Json, workflow_inputs: Json, |
| 238 | + user_env_vars: Dict[str, str] = {}, local_instance: bool = False) -> Optional[int]: |
| 239 | +    """This function runs the compiled workflow through Compute. |
| 240 | +
| 241 | + Args: |
| 242 | + workflow_name (str): The name of the workflow |
| 243 | + workflow (Json): The compiled CWL workflow |
| 244 | +        workflow_inputs (Json): The inputs for the compiled CWL workflow |
| 245 | +        user_env_vars (Dict[str, str]): User-supplied environment variables |
| 246 | +        local_instance (bool): Run on a local instance of Compute |
| 247 | +
| 248 | + Returns: |
| 249 | +        retval (Optional[int]): 0 if the run succeeded, 1 otherwise |
| 250 | + """ |
| 251 | +    # Update the process environment with the sanitized user-supplied variables |
| 252 | + os.environ.update(sanitize_env_vars(user_env_vars)) |
| 253 | + |
| 254 | + connect_timeout = 5 # in seconds |
| 255 | + read_timeout = 30 # in seconds |
| 256 | + timeout_tuple = (connect_timeout, read_timeout) |
| 257 | + # construct compute_workflow object to be submitted |
| 258 | + # append timestamp to the job/workflow name to create jobid |
| 259 | + now = datetime.now() |
| 260 | + date_time = now.strftime("%Y_%m_%d_%H.%M.%S") |
| 261 | +    jobid = workflow_name + '__' + date_time + '__' |
| 262 | + compute_workflow = { |
| 263 | + 'cwlWorkflow': workflow, |
| 264 | + 'cwlJobInputs': workflow_inputs, |
| 265 | + 'id': jobid, |
| 266 | + 'jobs': {} |
| 267 | + } |
| 268 | + |
| 269 | + base_url = '' |
| 270 | + if local_instance: |
| 271 | + # this is for localhost |
| 272 | + base_url = 'http://127.0.0.1:7998/compute/' |
| 273 | + else: |
| 274 | + # this is for actual compute URL |
| 275 | + base_url = 'http://dali-polus.ncats.nih.gov:7998/compute/' |
| 276 | + |
| 277 | + print('Sending request to Compute') |
| 278 | + res = requests.post(base_url, json=compute_workflow, timeout=timeout_tuple) |
| 279 | + print('Post response code: ' + str(res.status_code)) |
| 280 | + |
| 281 | + res = requests.get(base_url + f'{jobid}/outputs/', timeout=timeout_tuple) |
| 282 | + print('Output response code: ' + str(res.status_code)) |
| 283 | + retval = 0 if res.status_code == 200 else 1 |
| 284 | + print('Toil output: ' + str(res.text)) |
| 285 | + |
| 286 | + res = requests.get(base_url + f'{jobid}/logs/', timeout=timeout_tuple) |
| 287 | +    # 1. Parse the JSON response body into a Python dictionary |
| 288 | +    log_data = res.json() |
| 289 | + |
| 290 | + # 2. Extract the first key-value pair, which contains the main log content. |
| 291 | + # The key is the filename, and the value is the log text. |
| 292 | + first_key = list(log_data.keys())[0] |
| 293 | + log_content = log_data[first_key] |
| 294 | + |
| 295 | + print('Toil logs: ') |
| 296 | + pprint(log_content, indent=4) |
| 297 | + |
| 298 | + with open(f'compute_logs_{jobid}.txt', 'w', encoding='utf-8') as f: |
| 299 | + f.write(log_content) |
| 300 | + return retval |
| 301 | + |
| 302 | + |
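A minimal invocation sketch for run_compute, assuming the compiled workflow and its inputs have already been serialized to JSON; the file names, workflow name, and environment variable are placeholders:

```python
import json

# Hypothetical files produced by a prior compile step.
with open("workflow.cwl.json", encoding="utf-8") as f:
    wf = json.load(f)
with open("workflow_inputs.json", encoding="utf-8") as f:
    inputs = json.load(f)

status = run_compute("example_workflow", wf, inputs,
                     user_env_vars={"MY_VAR": "value"},  # illustrative
                     local_instance=True)
print("run_compute returned", status)
```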
192 | 303 | def copy_output_files(yaml_stem: str, basepath: str = '') -> None: |
193 | 304 | """Copies output files from the cachedir to outdir/ |
194 | 305 |