Commit 332c2f2

SQL multi statements + python tests (#1623)
## Changes

- Add a basic SQL-statement parser and execute each statement separately; `spark.sql` sadly can't do it for us.
- Move imports into function scopes (so they can't be shadowed by user imports).
- Add tests for magic parsing.

Fixes #1610

## Tests

Tests!
1 parent e8f690e commit 332c2f2
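
The headline change works around a Spark limitation: `spark.sql()` parses a single statement per call, so a multi-statement `%sql` cell has to be split before submission. A minimal sketch of the failure mode and the workaround (assumes an active SparkSession named `spark`; the table is illustrative):

    # One call carrying two ;-separated statements typically fails with a ParseException:
    spark.sql("CREATE TABLE t (id INT); INSERT INTO t VALUES (1)")
    # What the new magic handler effectively does instead -- one call per statement:
    for stmt in ["CREATE TABLE t (id INT)", "INSERT INTO t VALUES (1)"]:
        _sqldf = spark.sql(stmt)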

File tree

8 files changed (+428, -38 lines)

.github/workflows/unit-tests.yml

Lines changed: 13 additions & 0 deletions

@@ -63,3 +63,16 @@ jobs:
         with:
           run: yarn run test:cov
         working-directory: packages/databricks-vscode
+
+      - name: Install Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.12" # 3.13+ is not yet supported by the latest DBR
+
+      - name: Install Python dependencies
+        run: pip install ipython
+        working-directory: packages/databricks-vscode
+
+      - name: Python Unit Tests
+        run: yarn run test:python
+        working-directory: packages/databricks-vscode

packages/databricks-vscode/.gitignore

Lines changed: 1 addition & 0 deletions

@@ -6,6 +6,7 @@ extension/
 .pytest_cache/
 .build/
 **/tmp/**
+__pycache__/

 # Telemetry file, automatically generated by packages/databricks-vscode/scripts/generateTelemetry.ts
 telemetry.json

packages/databricks-vscode/.vscode/settings.json

Lines changed: 10 additions & 1 deletion

@@ -15,5 +15,14 @@
     },
     "[xml]": {
         "editor.defaultFormatter": "redhat.vscode-xml"
-    }
+    },
+    "python.testing.unittestArgs": [
+        "-v",
+        "-s",
+        "./src/test/python",
+        "-p",
+        "*test.py"
+    ],
+    "python.testing.pytestEnabled": false,
+    "python.testing.unittestEnabled": true
 }

packages/databricks-vscode/package.json

Lines changed: 1 addition & 0 deletions

@@ -1174,6 +1174,7 @@
         "fix": "eslint src --ext ts --fix && prettier . --write",
         "test:lint": "eslint src --ext ts && prettier . -c",
         "test:unit": "yarn run build && node ./out/test/runTest.js",
+        "test:python": "DATABRICKS_EXTENSION_UNIT_TESTS=1 python -B -m unittest discover -s ./src/test/python -p '*_test.py'",
         "test:integ:prepare": "yarn run package",
         "test:integ:extension": "yarn run test:integ:prepare && wdio run src/test/e2e/wdio.conf.ts",
         "test:integ:sdk": "ts-mocha --type-check 'src/sdk-extensions/**/*.integ.ts'",

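The `test:python` script above discovers any file under src/test/python whose name ends in `_test.py` (the VS Code settings in .vscode/settings.json use the broader `*test.py` pattern, which matches the same files). The commit's actual test files are among the two not shown in this view; a hypothetical module illustrating the naming and layout that discovery expects:

    # src/test/python/example_test.py (hypothetical file, for illustration only)
    import unittest

    class ExampleTest(unittest.TestCase):
        def test_addition(self):
            self.assertEqual(1 + 1, 2)

    if __name__ == "__main__":
        unittest.main()
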
packages/databricks-vscode/resources/python/00-databricks-init.py

Lines changed: 143 additions & 37 deletions

@@ -1,16 +1,7 @@
-from contextlib import contextmanager
-import functools
-import json
 from typing import Any, Union, List
-import os
-import sys
-# Avoid conflicts with possible "datetime" imports (in the user code)
-import time as _time
-import shlex
-import warnings
-import tempfile

 # prevent sum from pyspark.sql.functions from shadowing the builtin sum
+import sys
 builtinSum = sys.modules['builtins'].sum

 def logError(function_name: str, e: Union[str, Exception]):
@@ -44,6 +35,8 @@ def disposable(f):


 def logErrorAndContinue(f):
+    import functools
+
     @functools.wraps(f)
     def wrapper(*args, **kwargs):
         try:
@@ -56,6 +49,7 @@ def wrapper(*args, **kwargs):
 @logErrorAndContinue
 @disposable
 def load_env_from_leaf(path: str) -> bool:
+    import os
     curdir = path if os.path.isdir(path) else os.path.dirname(path)
     env_file_path = os.path.join(curdir, ".databricks", ".databricks.env")
     if os.path.exists(env_file_path):
@@ -106,6 +100,7 @@ def __init__(self, env_name: str, default: any = None, required: bool = False):
         self.required = required

     def __get__(self, instance, owner):
+        import os
         if self.env_name in os.environ:
             if self.transform is not bool:
                 return self.transform(os.environ[self.env_name])
@@ -144,6 +139,7 @@ class DatabricksMagics(Magics):
     @needs_local_scope
     @line_magic
     def fs(self, line: str, local_ns):
+        import shlex
         args = shlex.split(line)
         if len(args) == 0:
             return
@@ -163,6 +159,7 @@ def fs(self, line: str, local_ns):


 def is_databricks_notebook(py_file: str):
+    import os
     if os.path.exists(py_file):
         with open(py_file, "r") as f:
             return "Databricks notebook source" in f.readline()
@@ -175,6 +172,9 @@ def strip_hash_magic(lines: List[str]) -> List[str]:
     return lines

 def convert_databricks_notebook_to_ipynb(py_file: str):
+    import os
+    import json
+
     cells: List[dict[str, Any]] = [
         {
             "cell_type": "code",
@@ -205,10 +205,13 @@ def convert_databricks_notebook_to_ipynb(py_file: str):
         'nbformat_minor': 2
     })

-
+
+from contextlib import contextmanager
 @contextmanager
 def databricks_notebook_exec_env(project_root: str, py_file: str):
+    import os
     import sys
+    import tempfile
     old_sys_path = sys.path
     old_cwd = os.getcwd()

@@ -229,13 +232,107 @@ def databricks_notebook_exec_env(project_root: str, py_file: str):
     os.chdir(old_cwd)


-@logErrorAndContinue
+"""
+Splits an SQL string into individual statements using a recursive-descent parsing technique.
+Handles semicolons in strings and comments. Most probably breaks in dozens of other edge cases...
+"""
+class SqlStatementParser:
+    def __init__(self, sql):
+        self.sql = sql
+        self.position = 0
+        self.statements = []
+        self.current = []
+
+    def parse(self):
+        while self.position < len(self.sql):
+            char = self.peek()
+            next_char = self.peek_next()
+            if char == '-' and next_char == '-':
+                self.parse_line_comment()
+            elif char == '/' and next_char == '*':
+                self.parse_block_comment()
+            elif char == "'":
+                self.parse_string("'")
+            elif char == '"':
+                self.parse_string('"')
+            elif char == '`':
+                self.parse_string('`')
+            elif char == ';':
+                self.position += 1  # Skip the semicolon itself
+                self.add_statement()
+            else:
+                self.consume()
+        self.add_statement()  # Add the last statement if there is one
+        return self.statements
+
+    def peek(self):
+        if self.position < len(self.sql):
+            return self.sql[self.position]
+        return None
+
+    def peek_next(self):
+        if self.position + 1 < len(self.sql):
+            return self.sql[self.position + 1]
+        return None
+
+    def consume(self):
+        char = self.peek()
+        if char is not None:
+            self.position += 1
+            self.current.append(char)
+        return char
+
+    def consume_next(self):
+        char, next_char = self.peek(), self.peek_next()
+        if char is not None and next_char is not None:
+            self.position += 2
+            self.current.extend([char, next_char])
+        return char, next_char
+
+    def add_statement(self):
+        if self.current:
+            stmt = ''.join(self.current).strip()
+            if stmt:
+                self.statements.append(stmt)
+            self.current = []
+
+    def parse_line_comment(self):
+        self.consume_next()  # Consume "--" that starts the comment
+        while self.peek() is not None:
+            if self.peek() == '\n':
+                self.consume()
+                return
+            self.consume()
+
+    def parse_block_comment(self):
+        self.consume_next()  # Consume "/*" that starts the comment
+        while self.peek() is not None:
+            if self.peek() == '*' and self.peek_next() == '/':
+                self.consume_next()  # Consume "*/" that ends the comment
+                return
+            self.consume()
+
+    def parse_string(self, quote_char):
+        self.consume()  # Consume the opening quote
+        while self.peek() is not None:
+            # Handle escaped quote
+            if self.peek() == '\\' and self.peek_next() == quote_char:
+                self.consume_next()  # Consume the escaped quote
+            elif self.peek() == quote_char:
+                self.consume()  # Consume the closing quote
+                return
+            else:
+                self.consume()
+
 @disposable
-def register_magics(cfg: LocalDatabricksNotebookConfig):
+def create_databricks_magics_transformer(cfg: LocalDatabricksNotebookConfig):
+    import os
+    import warnings
+
     def warn_for_dbr_alternative(magic: str):
         # Magics that are not supported on Databricks but work in jupyter notebooks.
         # We show a warning, prompting users to use a databricks equivalent instead.
-        local_magic_dbr_alternative = {"%%sh": "%sh"}
+        local_magic_dbr_alternative = {"%%sh": "sh"}
         if magic in local_magic_dbr_alternative:
             warnings.warn(
                 "\n" + magic
@@ -247,7 +344,7 @@ def warn_for_dbr_alternative(magic: str):

     def throw_if_not_supported(magic: str):
         # These are magics that are supported on dbr but not locally.
-        unsupported_dbr_magics = ["%r", "%scala"]
+        unsupported_dbr_magics = ["r", "scala"]
         if magic in unsupported_dbr_magics:
             raise NotImplementedError(
                 magic
@@ -300,14 +397,14 @@ def handle(lines: List[str]):

         if lmagic == "sql":
             lines = lines[1:]
-            spark_string = (
-                "global _sqldf\n"
-                + "_sqldf = spark.sql('''"
-                + "".join(lines).replace("'", "\\'")
-                + "''')\n"
-                + "_sqldf"
-            )
-            return spark_string.splitlines(keepends=True)
+            sql_string = "".join(lines)
+            statements = SqlStatementParser(sql_string).parse()
+            result_code = ["global _sqldf\n"]
+            for _, stmt in enumerate(statements):
+                quoted_stmt = stmt.replace("'", "\\'")
+                result_code.append(f"_sqldf = spark.sql('''{quoted_stmt}''')\n")
+            result_code.append("_sqldf")
+            return result_code

         if lmagic == "python":
             return lines[1:]
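
Taken together, the parser and the `%sql` handler above turn one cell into a sequence of `spark.sql()` calls, rebinding `_sqldf` each time so the cell evaluates to the result of the last statement; semicolons inside quoted strings or comments do not split. A hedged sketch (not part of the commit; table names illustrative) of both stages:

    cell = "DROP TABLE IF EXISTS t; -- reset; then rebuild\nCREATE TABLE t (s STRING);\nSELECT * FROM t"
    print(SqlStatementParser(cell).parse())
    # ['DROP TABLE IF EXISTS t',
    #  '-- reset; then rebuild\nCREATE TABLE t (s STRING)',
    #  'SELECT * FROM t']
    #
    # The handler then hands IPython roughly this code:
    #   global _sqldf
    #   _sqldf = spark.sql('''DROP TABLE IF EXISTS t''')
    #   _sqldf = spark.sql('''-- reset; then rebuild
    #   CREATE TABLE t (s STRING)''')
    #   _sqldf = spark.sql('''SELECT * FROM t''')
    #   _sqldf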
@@ -317,8 +414,9 @@ def handle(lines: List[str]):
         if len(rest) == 0:
             return lines

+        raw_filename = rest[0]
         # Strip whitespace or possible quotes around the filename
-        filename = rest[0].strip('\'" ')
+        filename = raw_filename.strip('\'" ')

         for suffix in ["", ".py", ".ipynb", ".ipy"]:
             if os.path.exists(os.path.join(os.getcwd(), filename + suffix)):
@@ -327,15 +425,14 @@ def handle(lines: List[str]):

                 return [
                     f"with databricks_notebook_exec_env(r'{cfg.project_root}', r'{filename}') as file:\n",
-                    "\t%run -i {file} " + lines[0].partition('%run')[2].partition(filename)[2] + "\n"
+                    "\t%run -i {file} " + lines[0].partition('%run')[2].partition(raw_filename)[2].strip() + "\n"
                 ]

         return lines

     is_line_magic.handle = handle
     return get_line_magic(lines) is not None

-
 def parse_line_for_databricks_magics(lines: List[str]):
     if len(lines) == 0:
         return lines
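
For the `%run` rewrite above, splitting the argument tail on the unstripped `raw_filename` (plus the new `.strip()`) means a quoted filename no longer leaks its closing quote into the arguments passed through to IPython. A hedged sketch of the transformation (project root and filename are illustrative):

    # Input cell line:
    #   %run "./utils.py" extra_arg
    # Lines handed back to IPython (roughly; {file} is substituted by IPython):
    #   with databricks_notebook_exec_env(r'/home/user/project', r'./utils.py') as file:
    #       %run -i {file} extra_arg
    # Previously the tail was split on the stripped name ./utils.py, so the closing
    # quote survived as a stray argument: %run -i {file} " extra_arg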
@@ -350,10 +447,16 @@ def parse_line_for_databricks_magics(lines: List[str]):
             return magic_check.handle(lines)

         return lines
+
+    return parse_line_for_databricks_magics
+

+@logErrorAndContinue
+@disposable
+def register_magics(cfg: LocalDatabricksNotebookConfig):
     ip = get_ipython()
     ip.register_magics(DatabricksMagics)
-    ip.input_transformers_cleanup.append(parse_line_for_databricks_magics)
+    ip.input_transformers_cleanup.append(create_databricks_magics_transformer(cfg))


 @logErrorAndContinue
@@ -372,6 +475,7 @@ def df_html(df):
 @logErrorAndContinue
 @disposable
 def register_spark_progress(spark, show_progress: bool):
+    import time
     try:
         import ipywidgets as widgets
     except Exception as e:
@@ -389,7 +493,7 @@ def __init__(
     ) -> None:
         self._ticks = None
         self._tick = None
-        self._started = _time.time()
+        self._started = time.time()
         self._bytes_read = 0
         self._running = 0
         self.init_ui()
@@ -430,7 +534,7 @@ def update_ticks(
     def output(self) -> None:
         if self._tick is not None and self._ticks is not None:
             percent_complete = (self._tick / self._ticks) * 100
-            elapsed = int(_time.time() - self._started)
+            elapsed = int(time.time() - self._started)
             scanned = self._bytes_to_string(self._bytes_read)
             running = self._running
             self.w_progress.value = percent_complete
@@ -475,6 +579,7 @@ def __call__(self,
 @logErrorAndContinue
 @disposable
 def update_sys_path(notebook_config: LocalDatabricksNotebookConfig):
+    import sys
     sys.path.append(notebook_config.project_root)


@@ -486,14 +591,14 @@ def make_matplotlib_inline():
     except Exception as e:
         pass

-
-global _sqldf
-
-try:
+@logErrorAndContinue
+@disposable
+def setup():
+    import os
     import sys
-
     print(sys.modules[__name__])

+    global _sqldf
     # Suppress grpc warnings coming from databricks-connect with newer version of grpcio lib
     os.environ["GRPC_VERBOSITY"] = "NONE"

@@ -515,8 +620,9 @@

     for i in __disposables + ['__disposables']:
         globals().pop(i)
-    globals().pop('i')
     globals().pop('disposable')

-except Exception as e:
-    logError("unknown", e)
+
+import os
+if not os.environ.get("DATABRICKS_EXTENSION_UNIT_TESTS"):
+    setup()
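
The `DATABRICKS_EXTENSION_UNIT_TESTS` gate at the bottom is what lets the new Python tests load this file's helpers without running `setup()` (which expects a live IPython/Spark environment). A sketch of the idea; the loader mechanics are an assumption, since the test files themselves are not shown in this view:

    # Load 00-databricks-init.py with setup() disabled (assumed test-side pattern).
    import importlib.util, os
    os.environ["DATABRICKS_EXTENSION_UNIT_TESTS"] = "1"
    spec = importlib.util.spec_from_file_location(
        "databricks_init", "resources/python/00-databricks-init.py")
    mod = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(mod)  # module-level defs run; setup() is skipped
    assert mod.SqlStatementParser("SELECT 1; SELECT 2").parse() == ["SELECT 1", "SELECT 2"]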

packages/databricks-vscode/resources/python/dbconnect-bootstrap.py

Lines changed: 2 additions & 0 deletions

@@ -30,6 +30,8 @@ def load_env_from_leaf(path: str) -> bool:

 # Suppress grpc warnings coming from databricks-connect with newer version of grpcio lib
 os.environ["GRPC_VERBOSITY"] = "NONE"
+# Ensures that dbutils.notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get() returns the correct path
+os.environ["DATABRICKS_SOURCE_FILE"] = script

 project_dir = load_env_from_leaf(cur_dir)

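Setting the variable at bootstrap time makes locally launched code see its own file as the notebook path; the sketch below just exercises the call chain named in the comment above (requires a configured `dbutils`):

    # Should now return the path of the script being run:
    path = dbutils.notebook.entry_point.getDbutils().notebook().getContext().notebookPath().get()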