|
3 | 3 | import logging |
4 | 4 | import uuid |
5 | 5 | import asyncio |
| 6 | +import textwrap |
6 | 7 |
|
7 | 8 | from asyncio import Queue |
8 | 9 | from typing import ( |
@@ -177,6 +178,32 @@ def _reset_env_vars_code(self, env_vars: Dict[StrictStr, str]) -> str: |
177 | 178 |
|
178 | 179 | return "\n".join(cleanup_commands) |
179 | 180 |
|
| 181 | + def _get_code_indentation(self, code: str) -> str: |
| 182 | + """Get the indentation from the first non-empty line of code.""" |
| 183 | + if not code or not code.strip(): |
| 184 | + return "" |
| 185 | + |
| 186 | + lines = code.split('\n') |
| 187 | + for line in lines: |
| 188 | + if line.strip(): # First non-empty line |
| 189 | + return line[:len(line) - len(line.lstrip())] |
| 190 | + |
| 191 | + return "" |
| 192 | + |
| 193 | + def _indent_code_with_level(self, code: str, indent_level: str) -> str: |
| 194 | + """Apply the given indentation level to each line of code.""" |
| 195 | + if not code or not indent_level: |
| 196 | + return code |
| 197 | + |
| 198 | + lines = code.split('\n') |
| 199 | + indented_lines = [] |
| 200 | + |
| 201 | + for line in lines: |
| 202 | + if line.strip(): # Non-empty lines |
| 203 | + indented_lines.append(indent_level + line) |
| 204 | + |
| 205 | + return '\n'.join(indented_lines) |
| 206 | + |
180 | 207 | async def _cleanup_env_vars(self, env_vars: Dict[StrictStr, str]): |
181 | 208 | """Clean up environment variables in a separate execution request.""" |
182 | 209 | message_id = str(uuid.uuid4()) |
@@ -258,21 +285,28 @@ async def execute( |
258 | 285 | raise Exception("WebSocket not connected") |
259 | 286 |
|
260 | 287 | async with self._lock: |
| 288 | + # Get the indentation level from the code |
| 289 | + code_indent = self._get_code_indentation(code) |
| 290 | + |
261 | 291 | # Build the complete code snippet with env vars |
262 | 292 | complete_code = code |
263 | 293 |
|
264 | 294 | if not self.global_env_vars: |
265 | 295 | self.global_env_vars = await get_envs() |
266 | 296 |
|
267 | 297 | if not self.global_env_vars_set and self.global_env_vars: |
268 | | - complete_code = f"{self._set_env_vars_code(self.global_env_vars)}\n{complete_code}" |
| 298 | + env_setup_code = self._set_env_vars_code(self.global_env_vars) |
| 299 | + if env_setup_code: |
| 300 | + indented_env_code = self._indent_code_with_level(env_setup_code, code_indent) |
| 301 | + complete_code = f"{indented_env_code}\n{complete_code}" |
269 | 302 | self.global_env_vars_set = True |
270 | 303 |
|
271 | 304 | if env_vars: |
272 | 305 | # Add env var setup at the beginning |
273 | 306 | env_setup_code = self._set_env_vars_code(env_vars) |
274 | 307 | if env_setup_code: |
275 | | - complete_code = f"{env_setup_code}\n{complete_code}" |
| 308 | + indented_env_code = self._indent_code_with_level(env_setup_code, code_indent) |
| 309 | + complete_code = f"{indented_env_code}\n{complete_code}" |
276 | 310 |
|
277 | 311 | logger.info(f"Executing complete code: {complete_code}") |
278 | 312 | request = self._get_execute_request(message_id, complete_code, False) |
@@ -448,7 +482,8 @@ async def close(self): |
448 | 482 | if self._ws is not None: |
449 | 483 | await self._ws.close() |
450 | 484 |
|
451 | | - self._receive_task.cancel() |
| 485 | + if self._receive_task is not None: |
| 486 | + self._receive_task.cancel() |
452 | 487 |
|
453 | 488 | for execution in self._executions.values(): |
454 | 489 | execution.queue.put_nowait(UnexpectedEndOfExecution()) |
0 commit comments