-from contextlib import contextmanager
-import functools
-import json
 from typing import Any, Union, List
-import os
-import sys
-# Avoid conflicts with possible "datetime" imports (in the user code)
-import time as _time
-import shlex
-import warnings
-import tempfile
 
 # prevent sum from pyspark.sql.functions from shadowing the builtin sum
+import sys
 builtinSum = sys.modules['builtins'].sum
 
 def logError(function_name: str, e: Union[str, Exception]):
@@ -44,6 +35,8 @@ def disposable(f):
 
 
 def logErrorAndContinue(f):
+    import functools
+
     @functools.wraps(f)
     def wrapper(*args, **kwargs):
         try:
@@ -56,6 +49,7 @@ def wrapper(*args, **kwargs):
 @logErrorAndContinue
 @disposable
 def load_env_from_leaf(path: str) -> bool:
+    import os
     curdir = path if os.path.isdir(path) else os.path.dirname(path)
     env_file_path = os.path.join(curdir, ".databricks", ".databricks.env")
     if os.path.exists(env_file_path):
@@ -106,6 +100,7 @@ def __init__(self, env_name: str, default: any = None, required: bool = False):
         self.required = required
 
     def __get__(self, instance, owner):
+        import os
         if self.env_name in os.environ:
             if self.transform is not bool:
                 return self.transform(os.environ[self.env_name])
@@ -144,6 +139,7 @@ class DatabricksMagics(Magics):
     @needs_local_scope
     @line_magic
     def fs(self, line: str, local_ns):
+        import shlex
         args = shlex.split(line)
         if len(args) == 0:
             return
@@ -163,6 +159,7 @@ def fs(self, line: str, local_ns):
 
 
 def is_databricks_notebook(py_file: str):
+    import os
     if os.path.exists(py_file):
         with open(py_file, "r") as f:
             return "Databricks notebook source" in f.readline()
@@ -175,6 +172,9 @@ def strip_hash_magic(lines: List[str]) -> List[str]:
     return lines
 
 def convert_databricks_notebook_to_ipynb(py_file: str):
+    import os
+    import json
+
     cells: List[dict[str, Any]] = [
         {
             "cell_type": "code",
@@ -205,10 +205,13 @@ def convert_databricks_notebook_to_ipynb(py_file: str):
         'nbformat_minor': 2
     })
 
-
+
+from contextlib import contextmanager
 @contextmanager
 def databricks_notebook_exec_env(project_root: str, py_file: str):
+    import os
     import sys
+    import tempfile
     old_sys_path = sys.path
     old_cwd = os.getcwd()
 
@@ -229,13 +232,107 @@ def databricks_notebook_exec_env(project_root: str, py_file: str):
         os.chdir(old_cwd)
 
 
-@logErrorAndContinue
+"""
+Splits an SQL string into individual statements using a recursive-descent parsing
+technique. Handles semicolons inside strings and comments. Most probably breaks
+in dozens of other edge cases...
+"""
+class SqlStatementParser:
+    def __init__(self, sql):
+        self.sql = sql
+        self.position = 0
+        self.statements = []
+        self.current = []
+
+    def parse(self):
+        while self.position < len(self.sql):
+            char = self.peek()
+            next_char = self.peek_next()
+            if char == '-' and next_char == '-':
+                self.parse_line_comment()
+            elif char == '/' and next_char == '*':
+                self.parse_block_comment()
+            elif char == "'":
+                self.parse_string("'")
+            elif char == '"':
+                self.parse_string('"')
+            elif char == '`':
+                self.parse_string('`')
+            elif char == ';':
+                self.position += 1  # Skip the semicolon itself
+                self.add_statement()
+            else:
+                self.consume()
+        self.add_statement()  # Add the last statement if there is one
+        return self.statements
+
+    def peek(self):
+        if self.position < len(self.sql):
+            return self.sql[self.position]
+        return None
+
+    def peek_next(self):
+        if self.position + 1 < len(self.sql):
+            return self.sql[self.position + 1]
+        return None
+
+    def consume(self):
+        char = self.peek()
+        if char is not None:
+            self.position += 1
+            self.current.append(char)
+        return char
+
+    def consume_next(self):
+        char, next_char = self.peek(), self.peek_next()
+        if char is not None and next_char is not None:
+            self.position += 2
+            self.current.extend([char, next_char])
+        return char, next_char
+
+    def add_statement(self):
+        if self.current:
+            stmt = ''.join(self.current).strip()
+            if stmt:
+                self.statements.append(stmt)
+            self.current = []
+
+    def parse_line_comment(self):
+        self.consume_next()  # Consume the "--" that starts the comment
+        while self.peek() is not None:
+            if self.peek() == '\n':
+                self.consume()
+                return
+            self.consume()
+
+    def parse_block_comment(self):
+        self.consume_next()  # Consume the "/*" that starts the comment
+        while self.peek() is not None:
+            if self.peek() == '*' and self.peek_next() == '/':
+                self.consume_next()  # Consume the "*/" that ends the comment
+                return
+            self.consume()
+
+    def parse_string(self, quote_char):
+        self.consume()  # Consume the opening quote
+        while self.peek() is not None:
+            # Handle an escaped quote
+            if self.peek() == '\\' and self.peek_next() == quote_char:
+                self.consume_next()  # Consume the escaped quote
+            elif self.peek() == quote_char:
+                self.consume()  # Consume the closing quote
+                return
+            else:
+                self.consume()
+
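+# Illustrative example: semicolons inside quotes or comments do not split, and
+# the splitting semicolon itself is dropped from each returned statement:
+#
+#   SqlStatementParser("SELECT 1; SELECT ';' AS semi -- not a splitter; really").parse()
+#   # -> ["SELECT 1", "SELECT ';' AS semi -- not a splitter; really"]
+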
 @disposable
-def register_magics(cfg: LocalDatabricksNotebookConfig):
+def create_databricks_magics_transformer(cfg: LocalDatabricksNotebookConfig):
+    import os
+    import warnings
+
     def warn_for_dbr_alternative(magic: str):
         # Magics that are not supported on Databricks but work in Jupyter notebooks.
         # We show a warning, prompting users to use a Databricks equivalent instead.
-        local_magic_dbr_alternative = {"%%sh": "%sh"}
+        local_magic_dbr_alternative = {"%%sh": "sh"}
         if magic in local_magic_dbr_alternative:
             warnings.warn(
                 "\n" + magic
@@ -247,7 +344,7 @@ def warn_for_dbr_alternative(magic: str):
 
     def throw_if_not_supported(magic: str):
         # These are magics that are supported on dbr but not locally.
-        unsupported_dbr_magics = ["%r", "%scala"]
+        unsupported_dbr_magics = ["r", "scala"]
         if magic in unsupported_dbr_magics:
             raise NotImplementedError(
                 magic
@@ -300,14 +397,14 @@ def handle(lines: List[str]):
 
             if lmagic == "sql":
                 lines = lines[1:]
-                spark_string = (
-                    "global _sqldf\n"
-                    + "_sqldf = spark.sql('''"
-                    + "".join(lines).replace("'", "\\'")
-                    + "''')\n"
-                    + "_sqldf"
-                )
-                return spark_string.splitlines(keepends=True)
+                sql_string = "".join(lines)
+                statements = SqlStatementParser(sql_string).parse()
+                result_code = ["global _sqldf\n"]
+                for stmt in statements:
+                    quoted_stmt = stmt.replace("'", "\\'")
+                    result_code.append(f"_sqldf = spark.sql('''{quoted_stmt}''')\n")
+                result_code.append("_sqldf")
+                return result_code
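+            # The branch above rewrites a cell such as "%sql\nSELECT 1;\nSELECT 2"
+            # into (illustrative):
+            #   global _sqldf
+            #   _sqldf = spark.sql('''SELECT 1''')
+            #   _sqldf = spark.sql('''SELECT 2''')
+            #   _sqldf
+            # so _sqldf ends up holding the result of the last statement.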
 
             if lmagic == "python":
                 return lines[1:]
@@ -317,8 +414,9 @@ def handle(lines: List[str]):
                 if len(rest) == 0:
                     return lines
 
+                raw_filename = rest[0]
                 # Strip whitespace or possible quotes around the filename
-                filename = rest[0].strip('\'" ')
+                filename = raw_filename.strip('\'" ')
 
                 for suffix in ["", ".py", ".ipynb", ".ipy"]:
                     if os.path.exists(os.path.join(os.getcwd(), filename + suffix)):
@@ -327,15 +425,14 @@ def handle(lines: List[str]):
 
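+                # Illustrative: a line like "%run ./my_nb arg1" resolves ./my_nb
+                # against the working directory (trying the suffixes above) and is
+                # rewritten to execute inside databricks_notebook_exec_env, with any
+                # trailing arguments preserved after the {file} placeholder.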
                 return [
                     f"with databricks_notebook_exec_env(r'{cfg.project_root}', r'{filename}') as file:\n",
-                    "\t%run -i {file} " + lines[0].partition('%run')[2].partition(filename)[2] + "\n"
+                    "\t%run -i {file} " + lines[0].partition('%run')[2].partition(raw_filename)[2].strip() + "\n"
                 ]
 
             return lines
 
         is_line_magic.handle = handle
         return get_line_magic(lines) is not None
 
-
     def parse_line_for_databricks_magics(lines: List[str]):
         if len(lines) == 0:
             return lines
@@ -350,10 +447,16 @@ def parse_line_for_databricks_magics(lines: List[str]):
             return magic_check.handle(lines)
 
         return lines
+
+    return parse_line_for_databricks_magics
+
 
+@logErrorAndContinue
+@disposable
+def register_magics(cfg: LocalDatabricksNotebookConfig):
     ip = get_ipython()
     ip.register_magics(DatabricksMagics)
-    ip.input_transformers_cleanup.append(parse_line_for_databricks_magics)
+    ip.input_transformers_cleanup.append(create_databricks_magics_transformer(cfg))
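+    # IPython calls each input_transformers_cleanup entry with the cell's source
+    # as a list of lines and expects a list of lines back; the callable returned
+    # by create_databricks_magics_transformer follows that contract.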
 
 
 @logErrorAndContinue
@@ -372,6 +475,7 @@ def df_html(df):
 @logErrorAndContinue
 @disposable
 def register_spark_progress(spark, show_progress: bool):
+    import time
     try:
         import ipywidgets as widgets
     except Exception as e:
@@ -389,7 +493,7 @@ def __init__(
         ) -> None:
             self._ticks = None
             self._tick = None
-            self._started = _time.time()
+            self._started = time.time()
             self._bytes_read = 0
             self._running = 0
             self.init_ui()
@@ -430,7 +534,7 @@ def update_ticks(
         def output(self) -> None:
             if self._tick is not None and self._ticks is not None:
                 percent_complete = (self._tick / self._ticks) * 100
-                elapsed = int(_time.time() - self._started)
+                elapsed = int(time.time() - self._started)
                 scanned = self._bytes_to_string(self._bytes_read)
                 running = self._running
                 self.w_progress.value = percent_complete
@@ -475,6 +579,7 @@ def __call__(self,
 @logErrorAndContinue
 @disposable
 def update_sys_path(notebook_config: LocalDatabricksNotebookConfig):
+    import sys
     sys.path.append(notebook_config.project_root)
 
 
@@ -486,14 +591,14 @@ def make_matplotlib_inline():
     except Exception as e:
         pass
 
-
-global _sqldf
-
-try:
+@logErrorAndContinue
+@disposable
+def setup():
+    import os
     import sys
-
     print(sys.modules[__name__])
 
+    global _sqldf
     # Suppress grpc warnings coming from databricks-connect with newer versions of the grpcio lib
     os.environ["GRPC_VERBOSITY"] = "NONE"
 
@@ -515,8 +620,9 @@ def make_matplotlib_inline():
 
     for i in __disposables + ['__disposables']:
         globals().pop(i)
-    globals().pop('i')
     globals().pop('disposable')
 
-except Exception as e:
-    logError("unknown", e)
+
+import os
+if not os.environ.get("DATABRICKS_EXTENSION_UNIT_TESTS"):
+    setup()