|
| 1 | +import libcst as cst |
| 2 | +import tempfile |
| 3 | +import os |
| 4 | +import shutil |
| 5 | +from typing import Dict, Optional |
| 6 | + |
| 7 | +class Injector: |
| 8 | + """ |
| 9 | + Unified injector for Python files using libcst. |
| 10 | + Handles both variable and function injection. |
| 11 | + All operations work on temp files automatically. |
| 12 | + """ |
| 13 | + |
| 14 | + def __init__(self, python_file_path: str = None, code: str = None, |
| 15 | + module: cst.Module = None, filename: str = "script.py"): |
| 16 | + """ |
| 17 | + Args: |
| 18 | + python_file_path: Path to original Python file (read-only, copied to temp) |
| 19 | + code: Python code as string (alternative to file_path) |
| 20 | + module: CST Module object (most efficient - no parsing needed) |
| 21 | + filename: Name for temp file (used when code/module is provided) |
| 22 | + """ |
| 23 | + # Validate arguments |
| 24 | + provided = sum([bool(python_file_path), bool(code), bool(module)]) |
| 25 | + if provided > 1: |
| 26 | + raise ValueError("Cannot provide multiple sources (python_file_path, code, or module)") |
| 27 | + if provided == 0: |
| 28 | + raise ValueError("Must provide either python_file_path, code, or module") |
| 29 | + |
| 30 | + # Get module from file, string, or use provided CST Module |
| 31 | + if python_file_path: |
| 32 | + self.python_file_path = python_file_path |
| 33 | + # Read original file (read-only) |
| 34 | + with open(python_file_path, 'r', encoding='utf-8') as f: |
| 35 | + self._original_code = f.read() |
| 36 | + temp_filename = os.path.basename(python_file_path) |
| 37 | + # Parse using libcst |
| 38 | + self._module = cst.parse_module(self._original_code) |
| 39 | + elif module: |
| 40 | + self.python_file_path = None |
| 41 | + # Use provided CST Module (no parsing needed!) |
| 42 | + self._module = module |
| 43 | + self._original_code = module.code |
| 44 | + temp_filename = filename |
| 45 | + else: # code string |
| 46 | + self.python_file_path = None |
| 47 | + self._original_code = code |
| 48 | + # Parse using libcst |
| 49 | + self._module = cst.parse_module(code) |
| 50 | + temp_filename = filename |
| 51 | + |
| 52 | + # Create temp directory and file immediately |
| 53 | + self._temp_dir = tempfile.mkdtemp() |
| 54 | + self._temp_file_path = os.path.join(self._temp_dir, temp_filename) |
| 55 | + |
| 56 | + # Write initial copy to temp file |
| 57 | + with open(self._temp_file_path, 'w', encoding='utf-8') as f: |
| 58 | + f.write(self._original_code) |
| 59 | + |
| 60 | + # File pointer is closed, all future ops use temp file |
| 61 | + |
| 62 | + def _create_value_node(self, value): |
| 63 | + """Helper to create CST value node from Python value""" |
| 64 | + type_map = { |
| 65 | + str: lambda v: cst.SimpleString(f'"{v}"'), |
| 66 | + int: lambda v: cst.Integer(str(v)), |
| 67 | + float: lambda v: cst.Float(str(v)), |
| 68 | + bool: lambda v: cst.Name("True" if v else "False"), |
| 69 | + type(None): lambda v: cst.Name("None"), |
| 70 | + } |
| 71 | + return type_map.get(type(value), lambda v: cst.SimpleString(f'"{str(v)}"'))(value) |
| 72 | + |
| 73 | + def inject_variables(self, variables: Dict[str, any]): |
| 74 | + """ |
| 75 | + Inject variable assignments into the file. |
| 76 | + |
| 77 | + Args: |
| 78 | + variables: Dictionary of variable names and values |
| 79 | + at_top: If True, injects at top of file; if False, at end |
| 80 | + |
| 81 | + Returns: |
| 82 | + self (for chaining) |
| 83 | + """ |
| 84 | + assignments = [ |
| 85 | + cst.SimpleStatementLine(body=[ |
| 86 | + cst.Assign( |
| 87 | + targets=[cst.AssignTarget(target=cst.Name(var_name))], |
| 88 | + value=self._create_value_node(var_value) |
| 89 | + ) |
| 90 | + ]) |
| 91 | + for var_name, var_value in variables.items() |
| 92 | + ] |
| 93 | + |
| 94 | + # Apply transformation directly by modifying module body |
| 95 | + new_body = list(self._module.body) |
| 96 | + # Insert at beginning |
| 97 | + new_body = assignments + new_body |
| 98 | + |
| 99 | + self._module = self._module.with_changes(body=new_body) |
| 100 | + self._save_to_temp() |
| 101 | + |
| 102 | + return self |
| 103 | + |
| 104 | + def add_dependency(self, packages: list): |
| 105 | + """ |
| 106 | + Add pip install command at the top of the file using os.system. |
| 107 | + Also ensures 'import os' is present. |
| 108 | + |
| 109 | + Args: |
| 110 | + packages: List of package names to install |
| 111 | + |
| 112 | + Returns: |
| 113 | + self (for chaining) |
| 114 | + """ |
| 115 | + if not packages: |
| 116 | + return self |
| 117 | + |
| 118 | + # Check if 'import os' already exists |
| 119 | + has_os_import = False |
| 120 | + for item in self._module.body: |
| 121 | + if isinstance(item, cst.SimpleStatementLine): |
| 122 | + for stmt in item.body: |
| 123 | + if isinstance(stmt, cst.Import): |
| 124 | + for alias in stmt.names: |
| 125 | + if alias.name.value == 'os': |
| 126 | + has_os_import = True |
| 127 | + break |
| 128 | + elif isinstance(stmt, cst.ImportFrom) and stmt.module and stmt.module.value == 'os': |
| 129 | + has_os_import = True |
| 130 | + break |
| 131 | + |
| 132 | + # Create pip install command |
| 133 | + packages_str = ' '.join(packages) |
| 134 | + pip_command = f'pip install {packages_str}' |
| 135 | + |
| 136 | + # Create os.system call |
| 137 | + os_system_call = cst.SimpleStatementLine(body=[ |
| 138 | + cst.Expr(value=cst.Call( |
| 139 | + func=cst.Attribute( |
| 140 | + value=cst.Name('os'), |
| 141 | + attr=cst.Name('system') |
| 142 | + ), |
| 143 | + args=[cst.Arg(value=cst.SimpleString(f'"{pip_command}"'))] |
| 144 | + )) |
| 145 | + ]) |
| 146 | + |
| 147 | + # Build new body |
| 148 | + new_body = list(self._module.body) |
| 149 | + |
| 150 | + # Add import os if not present |
| 151 | + if not has_os_import: |
| 152 | + os_import = cst.SimpleStatementLine(body=[ |
| 153 | + cst.Import(names=[cst.ImportAlias(name=cst.Name('os'))]) |
| 154 | + ]) |
| 155 | + new_body.insert(0, os_import) |
| 156 | + # Insert os.system call after import |
| 157 | + new_body.insert(1, os_system_call) |
| 158 | + else: |
| 159 | + # Just insert os.system call at top |
| 160 | + new_body.insert(0, os_system_call) |
| 161 | + |
| 162 | + self._module = self._module.with_changes(body=new_body) |
| 163 | + self._save_to_temp() |
| 164 | + |
| 165 | + return self |
| 166 | + |
| 167 | + def inject_function(self, code: str, func_name: str): |
| 168 | + """ |
| 169 | + Inject code into existing function's body by replacing its body content. |
| 170 | + |
| 171 | + Args: |
| 172 | + code: Python code string to inject into function body |
| 173 | + func_name: Name of the existing function to modify |
| 174 | + |
| 175 | + Returns: |
| 176 | + self (for chaining) |
| 177 | + """ |
| 178 | + # Parse injected code as module to get statements |
| 179 | + injected_module = cst.parse_module(code) |
| 180 | + body_statements = list(injected_module.body) |
| 181 | + |
| 182 | + # Replace function body directly |
| 183 | + new_body = [ |
| 184 | + item.with_changes(body=cst.IndentedBlock(body=body_statements)) |
| 185 | + if isinstance(item, cst.FunctionDef) and item.name.value == func_name |
| 186 | + else item |
| 187 | + for item in self._module.body |
| 188 | + ] |
| 189 | + self._module = self._module.with_changes(body=new_body) |
| 190 | + |
| 191 | + self._save_to_temp() |
| 192 | + return self |
| 193 | + |
| 194 | + def _save_to_temp(self): |
| 195 | + """Internal: Save modified code to temp file""" |
| 196 | + with open(self._temp_file_path, 'w', encoding='utf-8') as f: |
| 197 | + f.write(self._module.code) |
| 198 | + |
| 199 | + def get_temp_file_path(self) -> str: |
| 200 | + """Get path to temporary file""" |
| 201 | + return self._temp_file_path |
| 202 | + |
| 203 | + def get_temp_dir(self) -> str: |
| 204 | + """Get path to temporary directory""" |
| 205 | + return self._temp_dir |
| 206 | + |
| 207 | + def get_code(self) -> str: |
| 208 | + """Get the modified code as string (for inspection)""" |
| 209 | + return self._module.code |
| 210 | + |
| 211 | + def destroy(self): |
| 212 | + """ |
| 213 | + Destroy all temporary files and directory. |
| 214 | + Call this when done with temp files. |
| 215 | + """ |
| 216 | + if self._temp_dir and os.path.exists(self._temp_dir): |
| 217 | + shutil.rmtree(self._temp_dir) |
| 218 | + self._temp_dir = None |
| 219 | + self._temp_file_path = None |
| 220 | + |
| 221 | + def __del__(self): |
| 222 | + """Automatically clean up temp files when object is destroyed""" |
| 223 | + # Only destroy if temp_dir still exists (destroy() not already called) |
| 224 | + if hasattr(self, '_temp_dir') and self._temp_dir and os.path.exists(self._temp_dir): |
| 225 | + try: |
| 226 | + shutil.rmtree(self._temp_dir) |
| 227 | + except (OSError, AttributeError): |
| 228 | + # Ignore errors during destruction (temp files may already be cleaned up) |
| 229 | + pass |
| 230 | + |
| 231 | + def __enter__(self): |
| 232 | + """Context manager support""" |
| 233 | + return self |
| 234 | + |
| 235 | + def __exit__(self, exc_type, exc_val, exc_tb): |
| 236 | + """Context manager cleanup""" |
| 237 | + self.destroy() |
0 commit comments