2323    UnexpectedEndOfExecution ,
2424)
2525from  errors  import  ExecutionError 
26+ from  envs  import  get_envs 
2627
2728logger  =  logging .getLogger (__name__ )
2829
@@ -47,7 +48,8 @@ def __init__(self, in_background: bool = False):
4748class  ContextWebSocket :
4849    _ws : Optional [WebSocketClientProtocol ] =  None 
4950    _receive_task : Optional [asyncio .Task ] =  None 
50-     global_env_vars : Optional [Dict [StrictStr , str ]] =  None 
51+     _global_env_vars : Optional [Dict [StrictStr , str ]] =  None 
52+     _cleanup_task : Optional [asyncio .Task ] =  None 
5153
5254    def  __init__ (
5355        self ,
@@ -114,6 +116,113 @@ def _get_execute_request(
114116            }
115117        )
116118
119+     def  _set_env_var_snippet (self , key : str , value : str ) ->  str :
120+         """Get environment variable set command for the current language.""" 
121+         if  self .language  ==  "python" :
122+             return  f"import os; os.environ['{ key } { value }  
123+         elif  self .language  in  ["javascript" , "typescript" ]:
124+             return  f"process.env['{ key } { value }  
125+         elif  self .language  ==  "deno" :
126+             return  f"Deno.env.set('{ key } { value }  
127+         elif  self .language  ==  "r" :
128+             return  f'Sys.setenv({ key } { value }  
129+         elif  self .language  ==  "java" :
130+             return  f'System.setProperty("{ key } { value }  
131+         elif  self .language  ==  "bash" :
132+             return  f"export { key } { value }  
133+         return  "" 
134+ 
135+     def  _delete_env_var_snippet (self , key : str ) ->  str :
136+         """Get environment variable delete command for the current language.""" 
137+         if  self .language  ==  "python" :
138+             return  f"import os; del os.environ['{ key }  
139+         elif  self .language  in  ["javascript" , "typescript" ]:
140+             return  f"delete process.env['{ key }  
141+         elif  self .language  ==  "deno" :
142+             return  f"Deno.env.delete('{ key }  
143+         elif  self .language  ==  "r" :
144+             return  f"Sys.unsetenv('{ key }  
145+         elif  self .language  ==  "java" :
146+             return  f'System.clearProperty("{ key }  
147+         elif  self .language  ==  "bash" :
148+             return  f"unset { key }  
149+         return  "" 
150+ 
151+     def  _set_env_vars_code (self , env_vars : Dict [StrictStr , str ]) ->  str :
152+         """Build environment variable code for the current language.""" 
153+         env_commands  =  []
154+         for  k , v  in  env_vars .items ():
155+             command  =  self ._set_env_var_snippet (k , v )
156+             if  command :
157+                 env_commands .append (command )
158+         
159+         return  "\n " .join (env_commands )
160+ 
161+     def  _reset_env_vars_code (self , env_vars : Dict [StrictStr , str ]) ->  str :
162+         """Build environment variable cleanup code for the current language.""" 
163+         cleanup_commands  =  []
164+         
165+         for  key  in  env_vars :
166+             # Check if this var exists in global env vars 
167+             if  self ._global_env_vars  and  key  in  self ._global_env_vars :
168+                 # Reset to global value 
169+                 value  =  self ._global_env_vars [key ]
170+                 command  =  self ._set_env_var_snippet (key , value )
171+             else :
172+                 # Remove the variable 
173+                 command  =  self ._delete_env_var_snippet (key )
174+             
175+             if  command :
176+                 cleanup_commands .append (command )
177+         
178+         return  "\n " .join (cleanup_commands )
179+ 
180+     def  _get_code_indentation (self , code : str ) ->  str :
181+         """Get the indentation from the first non-empty line of code.""" 
182+         if  not  code  or  not  code .strip ():
183+             return  "" 
184+         
185+         lines  =  code .split ('\n ' )
186+         for  line  in  lines :
187+             if  line .strip ():  # First non-empty line 
188+                 return  line [:len (line ) -  len (line .lstrip ())]
189+         
190+         return  "" 
191+ 
192+     def  _indent_code_with_level (self , code : str , indent_level : str ) ->  str :
193+         """Apply the given indentation level to each line of code.""" 
194+         if  not  code  or  not  indent_level :
195+             return  code 
196+         
197+         lines  =  code .split ('\n ' )
198+         indented_lines  =  []
199+         
200+         for  line  in  lines :
201+             if  line .strip ():  # Non-empty lines 
202+                 indented_lines .append (indent_level  +  line )
203+             else :
204+                 indented_lines .append (line )
205+         
206+         return  '\n ' .join (indented_lines )
207+ 
208+     async  def  _cleanup_env_vars (self , env_vars : Dict [StrictStr , str ]):
209+         """Clean up environment variables in a separate execution request.""" 
210+         message_id  =  str (uuid .uuid4 ())
211+         self ._executions [message_id ] =  Execution (in_background = True )
212+ 
213+         try :
214+             cleanup_code  =  self ._reset_env_vars_code (env_vars )
215+             if  cleanup_code :
216+                 logger .info (f"Cleaning up env vars: { cleanup_code }  )
217+                 request  =  self ._get_execute_request (message_id , cleanup_code , True )
218+                 await  self ._ws .send (request )
219+ 
220+                 async  for  item  in  self ._wait_for_result (message_id ):
221+                     if  item ["type" ] ==  "error" :
222+                         logger .error (f"Error during env var cleanup: { item }  )
223+         finally :
224+             del  self ._executions [message_id ]
225+ 
117226    async  def  _wait_for_result (self , message_id : str ):
118227        queue  =  self ._executions [message_id ].queue 
119228
@@ -133,84 +242,6 @@ async def _wait_for_result(self, message_id: str):
133242
134243            yield  output .model_dump (exclude_none = True )
135244
136-     async  def  set_env_vars (self , env_vars : Dict [StrictStr , str ]):
137-         message_id  =  str (uuid .uuid4 ())
138-         self ._executions [message_id ] =  Execution (in_background = True )
139- 
140-         env_commands  =  []
141-         for  k , v  in  env_vars .items ():
142-             if  self .language  ==  "python" :
143-                 env_commands .append (f"import os; os.environ['{ k } { v }  )
144-             elif  self .language  in  ["javascript" , "typescript" ]:
145-                 env_commands .append (f"process.env['{ k } { v }  )
146-             elif  self .language  ==  "deno" :
147-                 env_commands .append (f"Deno.env.set('{ k } { v }  )
148-             elif  self .language  ==  "r" :
149-                 env_commands .append (f'Sys.setenv({ k } { v }  )
150-             elif  self .language  ==  "java" :
151-                 env_commands .append (f'System.setProperty("{ k } { v }  )
152-             elif  self .language  ==  "bash" :
153-                 env_commands .append (f"export { k } { v }  )
154-             else :
155-                 return 
156- 
157-         if  env_commands :
158-             env_vars_snippet  =  "\n " .join (env_commands )
159-             logger .info (f"Setting env vars: { env_vars_snippet } { self .language }  )
160-             request  =  self ._get_execute_request (message_id , env_vars_snippet , True )
161-             await  self ._ws .send (request )
162- 
163-             async  for  item  in  self ._wait_for_result (message_id ):
164-                 if  item ["type" ] ==  "error" :
165-                     raise  ExecutionError (f"Error during execution: { item }  )
166- 
167-     async  def  reset_env_vars (self , env_vars : Dict [StrictStr , str ]):
168-         # Create a dict of vars to reset and a list of vars to remove 
169-         vars_to_reset  =  {}
170-         vars_to_remove  =  []
171- 
172-         for  key  in  env_vars :
173-             if  self .global_env_vars  and  key  in  self .global_env_vars :
174-                 vars_to_reset [key ] =  self .global_env_vars [key ]
175-             else :
176-                 vars_to_remove .append (key )
177- 
178-         # Reset vars that exist in global env vars 
179-         if  vars_to_reset :
180-             await  self .set_env_vars (vars_to_reset )
181- 
182-         # Remove vars that don't exist in global env vars 
183-         if  vars_to_remove :
184-             message_id  =  str (uuid .uuid4 ())
185-             self ._executions [message_id ] =  Execution (in_background = True )
186- 
187-             remove_commands  =  []
188-             for  key  in  vars_to_remove :
189-                 if  self .language  ==  "python" :
190-                     remove_commands .append (f"import os; del os.environ['{ key }  )
191-                 elif  self .language  in  ["javascript" , "typescript" ]:
192-                     remove_commands .append (f"delete process.env['{ key }  )
193-                 elif  self .language  ==  "deno" :
194-                     remove_commands .append (f"Deno.env.delete('{ key }  )
195-                 elif  self .language  ==  "r" :
196-                     remove_commands .append (f"Sys.unsetenv('{ key }  )
197-                 elif  self .language  ==  "java" :
198-                     remove_commands .append (f'System.clearProperty("{ key }  )
199-                 elif  self .language  ==  "bash" :
200-                     remove_commands .append (f"unset { key }  )
201-                 else :
202-                     return 
203-             
204-             if  remove_commands :
205-                 remove_snippet  =  "\n " .join (remove_commands )
206-                 logger .info (f"Removing env vars: { remove_snippet } { self .language }  )
207-                 request  =  self ._get_execute_request (message_id , remove_snippet , True )
208-                 await  self ._ws .send (request )
209- 
210-                 async  for  item  in  self ._wait_for_result (message_id ):
211-                     if  item ["type" ] ==  "error" :
212-                         raise  ExecutionError (f"Error during execution: { item }  )
213- 
214245    async  def  change_current_directory (
215246        self , path : Union [str , StrictStr ], language : str 
216247    ):
@@ -248,20 +279,44 @@ async def execute(
248279        env_vars : Dict [StrictStr , str ] =  None ,
249280    ):
250281        message_id  =  str (uuid .uuid4 ())
251-         logger .debug (f"Sending code for the execution ({ message_id } { code }  )
252- 
253282        self ._executions [message_id ] =  Execution ()
254283
255284        if  self ._ws  is  None :
256285            raise  Exception ("WebSocket not connected" )
257286
258287        async  with  self ._lock :
259-             # set env vars (will override global env vars) 
288+             # Wait for any pending cleanup task to complete 
289+             if  self ._cleanup_task  and  not  self ._cleanup_task .done ():
290+                 logger .debug ("Waiting for pending cleanup task to complete" )
291+                 try :
292+                     await  self ._cleanup_task 
293+                 except  Exception  as  e :
294+                     logger .warning (f"Cleanup task failed: { e }  )
295+                 finally :
296+                     self ._cleanup_task  =  None 
297+             
298+             # Get the indentation level from the code 
299+             code_indent  =  self ._get_code_indentation (code )
300+             
301+             # Build the complete code snippet with env vars 
302+             complete_code  =  code 
303+             
304+             global_env_vars_snippet  =  "" 
305+             env_vars_snippet  =  "" 
306+ 
307+             if  self ._global_env_vars  is  None :
308+                 self ._global_env_vars  =  await  get_envs ()
309+                 global_env_vars_snippet  =  self ._set_env_vars_code (self ._global_env_vars )
310+             
260311            if  env_vars :
261-                 await   self .set_env_vars (env_vars )
312+                 env_vars_snippet   =   self ._set_env_vars_code (env_vars )
262313
263-             logger .info (code )
264-             request  =  self ._get_execute_request (message_id , code , False )
314+             if  global_env_vars_snippet  or  env_vars_snippet :
315+                 indented_env_code  =  self ._indent_code_with_level (f"{ global_env_vars_snippet } \n { env_vars_snippet }  , code_indent )
316+                 complete_code  =  f"{ indented_env_code } \n { complete_code }  
317+ 
318+             logger .info (f"Sending code for the execution ({ message_id } { complete_code }  )
319+             request  =  self ._get_execute_request (message_id , complete_code , False )
265320
266321            # Send the code for execution 
267322            await  self ._ws .send (request )
@@ -272,9 +327,9 @@ async def execute(
272327
273328            del  self ._executions [message_id ]
274329
275-             # reset  env vars to their previous values, if they were set globally or remove them  
330+             # Clean up  env vars in a separate request after the main code has run  
276331            if  env_vars :
277-                 await   self .reset_env_vars (env_vars )
332+                 self . _cleanup_task   =   asyncio . create_task ( self ._cleanup_env_vars (env_vars ) )
278333
279334    async  def  _receive_message (self ):
280335        if  not  self ._ws :
@@ -434,7 +489,16 @@ async def close(self):
434489        if  self ._ws  is  not None :
435490            await  self ._ws .close ()
436491
437-         self ._receive_task .cancel ()
492+         if  self ._receive_task  is  not None :
493+             self ._receive_task .cancel ()
494+ 
495+         # Cancel any pending cleanup task 
496+         if  self ._cleanup_task  and  not  self ._cleanup_task .done ():
497+             self ._cleanup_task .cancel ()
498+             try :
499+                 await  self ._cleanup_task 
500+             except  asyncio .CancelledError :
501+                 pass 
438502
439503        for  execution  in  self ._executions .values ():
440504            execution .queue .put_nowait (UnexpectedEndOfExecution ())
0 commit comments