Commit 003065b

Author: James Robinson

Test and fix the bugfixes. (#86)

* Test and fix the bugfixes.
* Better test over the input fed into databricks-connect, proving newlines and no preamble whitespace that a human would not have entered.

1 parent 89a5a9f · commit 003065b

2 files changed: +133, -17 lines

noteable_magics/datasource_postprocessing.py (24 additions, 6 deletions)

@@ -2,7 +2,7 @@
 import shutil
 from base64 import b64decode
 from pathlib import Path
-from subprocess import PIPE, Popen
+from subprocess import PIPE, Popen, TimeoutExpired
 from tempfile import NamedTemporaryFile
 from typing import Any, Callable, Dict
 from urllib.parse import quote_plus, urlparse
@@ -264,21 +264,27 @@ def postprocess_awsathena(
     create_engine_kwargs['s3_staging_dir'] = quote_plus(create_engine_kwargs['s3_staging_dir'])


+DATABRICKS_CONNECT_SCRIPT_TIMEOUT = 10  # seconds
+
+
 @register_postprocessor('databricks+connector')
 def postprocess_databricks(
     datasource_id: str, dsn_dict: Dict[str, str], create_engine_kwargs: Dict[str, Any]
 ) -> None:
     """ENG-5517: If cluster_id is present, and `databricks-connect` is in the path, then
     set up and run it.

-    Also be sure to purge cluster_id, org_id, port from create_engine_kwargs, in that these
-    fields were added for only going into this side effect."""
+    Also be sure to purge cluster_id, org_id, port from the connect_args portion of
+    create_engine_kwargs, since those fields were added only to feed this side effect."""

     cluster_id_key = 'cluster_id'
     connect_file_opt_keys = [cluster_id_key, 'org_id', 'port']

     # Collect data to drive databricks-connect if we've got a cluster_id and script is in $PATH.
     connect_args = create_engine_kwargs['connect_args']
+    # Only wanted for getting connect_args. Any additional dereferencing is a bug.
+    del create_engine_kwargs
+
     if cluster_id_key in connect_args and shutil.which('databricks-connect'):
         # host, token (actually, our password field) come from dsn_dict.
         # (and what databricks-connect wants as 'host' is actually an https:// URL. Sigh.)
@@ -297,16 +303,28 @@ def postprocess_databricks(
             connect_file_path.unlink()

         p = Popen(['databricks-connect', 'configure'], stdout=PIPE, stdin=PIPE, stderr=PIPE)
-        _stdout, stderr = p.communicate(input=f"""y
+        try:
+            _stdout, stderr = p.communicate(
+                # Indentation is fugly so as to not prefix each input line with whitespace.
+                # And oh, be sure to have a newline between each input into the 'interactive' script.
+                input=f"""y
 {args['host']}
 {args['token']}
 {args[cluster_id_key]}
 {args['org_id']}
-{args['port']}""".encode(), timeout=10)
+{args['port']}""".encode(),
+                timeout=DATABRICKS_CONNECT_SCRIPT_TIMEOUT,
+            )
+        except TimeoutExpired:
+            raise ValueError(
+                f'databricks-connect took longer than {DATABRICKS_CONNECT_SCRIPT_TIMEOUT} seconds to complete.'
+            )

         if p.returncode != 0:
             # Failed to execute the script. Raise an exception.
-            raise ValueError("Failed to execute databricks-connect configure script: " + stderr)
+            raise ValueError(
+                "Failed to execute databricks-connect configure script: " + stderr.decode()
+            )

         # Always be sure to purge these only-for-databricks-connect file args from connect_args,
         # even if not all were present.
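
For context on the shape of the fix above: the pattern is to script an 'interactive' CLI by writing newline-separated answers (with no leading whitespace) to its stdin via Popen.communicate(), bounding the wait with a timeout and converting both TimeoutExpired and a nonzero exit status into a descriptive ValueError. The following is a minimal standalone sketch of that pattern, not the project's real code; the 'some-cli' command, SCRIPT_TIMEOUT value, and function name are all assumptions for illustration.

from subprocess import PIPE, Popen, TimeoutExpired

SCRIPT_TIMEOUT = 10  # seconds; assumed value for this sketch


def run_scripted_configure(answers):
    """Feed newline-separated answers to a hypothetical 'some-cli configure' prompt."""
    p = Popen(['some-cli', 'configure'], stdin=PIPE, stdout=PIPE, stderr=PIPE)
    try:
        _stdout, stderr = p.communicate(
            # Join with bare newlines so no answer picks up leading whitespace.
            input='\n'.join(answers).encode(),
            timeout=SCRIPT_TIMEOUT,
        )
    except TimeoutExpired:
        p.kill()
        raise ValueError(f'some-cli took longer than {SCRIPT_TIMEOUT} seconds to complete.')

    if p.returncode != 0:
        # stderr comes back as bytes; decode it before building the message
        # (the same decode issue fixed in the diff above).
        raise ValueError('Failed to execute some-cli configure script: ' + stderr.decode())

Joining a list with '\n' sidesteps the left-flushed triple-quoted f-string used in the real function, but the stdin contract is otherwise the same.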

tests/test_datasources.py (109 additions, 11 deletions)

@@ -3,7 +3,7 @@
 import json
 import os
 from pathlib import Path
-from typing import Callable, List
+from typing import Callable, List, Tuple
 from uuid import uuid4

 import pkg_resources
@@ -554,6 +554,7 @@ class TestDatabricks:

     @pytest.fixture()
     def tmp_home(self, tmpdir: Path) -> Path:
+        """Replace $HOME with a new directory under $TMPDIR, yielding the new Path."""
         existing_home = os.environ['HOME']

         new_home = tmpdir / 'home'
@@ -568,10 +569,13 @@ def tmp_home(self, tmpdir: Path) -> Path:
         os.environ['HOME'] = existing_home

     @pytest.fixture()
-    def databricks_connect_in_path(self, tmpdir: Path) -> Path:
-        # Get a mock-ish executable 'databricks-connect' into an element in the path
-        # so that which('databricks-connect') will find something (see databricks post
-        # processor)
+    def databricks_connect_in_path(self, tmpdir: Path) -> Tuple[Path, Path]:
+        """Get a mock-ish executable 'databricks-connect' into an element in the path
+        so that which('databricks-connect') will find something (see databricks post
+        processor).
+
+        Yields the new executable's path, plus where it will scribble its own output.
+        """

         # Make a new subdir of tmpdir, add it to the path, create executable
         # shell script databricks-connect
@@ -593,15 +597,14 @@ def databricks_connect_in_path(self, tmpdir: Path) -> Path:
         scriptpath.chmod(0o755)

         try:
-            # Yield the script output path so a test can inspect its contents.
-            yield script_output_path
+            yield scriptpath, script_output_path

         finally:
             # Undo $PATH change
             os.environ['PATH'] = orig_path

     @pytest.fixture()
-    def jsons_for_extra_behavior(self):
+    def jsons_for_extra_behavior(self) -> Tuple[DatasourceJSONs, dict]:
         """Return a DatasourceJSONs describing databricks that will tickle postprocess_databricks()
         into doing its extra behavior. Also returns dict of some of the fields within that JSON."""

@@ -643,6 +646,100 @@ def jsons_for_extra_behavior(self):
             },
         )

+    def test_postprocess_databricks_pops_correctly(self, datasource_id, jsons_for_extra_behavior):
+        """Ensure that the postprocess_databricks side effect pops from the correct dict
+        (connect_args, not the containing create_engine_kwargs dict), even w/o
+        databricks-connect being found in the $PATH.
+        """
+
+        keys_expected_to_be_removed = ['cluster_id', 'org_id', 'port']
+        jsons_obj, specific_fields = jsons_for_extra_behavior
+        connect_args = jsons_obj.connect_args_dict
+
+        # All initially there...
+        assert all(key in connect_args for key in keys_expected_to_be_removed)
+
+        create_engine_kwargs = {'connect_args': connect_args}
+
+        datasource_postprocessing.postprocess_databricks(
+            datasource_id,
+            jsons_obj.dsn_dict,
+            create_engine_kwargs,
+        )
+
+        # Should have removed all the keys as side effect of the call.
+        # (Had bug where they were popped from the wrong dict originally.)
+        assert not any(key in connect_args for key in keys_expected_to_be_removed)
+
+    def test_errors_from_databricks_connect_are_surfaced(
+        self, datasource_id, databricks_connect_in_path, tmp_home, jsons_for_extra_behavior
+    ):
+        """Prove that if the databricks-connect script exits nonzero, a ValueError is raised
+        and the script's stderr will be within the error message."""
+
+        # Respell the databricks-connect script to always error out; expect that error in a
+        # ValueError when calling postprocess_databricks.
+
+        script_path, _ = databricks_connect_in_path
+
+        expected_error_message = 'oh noes!'
+
+        # Respell the script to bomb out with a message to stderr.
+        with script_path.open('w') as of:
+            of.write('#!/bin/sh\n')
+            of.write(f'echo "{expected_error_message}" 1>&2\n')
+            of.write('exit 1\n')
+
+        jsons_obj, specific_fields = jsons_for_extra_behavior
+        create_engine_kwargs = {'connect_args': jsons_obj.connect_args_dict}
+
+        with pytest.raises(ValueError, match=expected_error_message):
+            datasource_postprocessing.postprocess_databricks(
+                datasource_id,
+                jsons_obj.dsn_dict,
+                create_engine_kwargs,
+            )
+
+    @pytest.fixture()
+    def short_script_timeout(self):
+        """Respell datasource_postprocessing.DATABRICKS_CONNECT_SCRIPT_TIMEOUT to 1 (second)."""
+        original_value = datasource_postprocessing.DATABRICKS_CONNECT_SCRIPT_TIMEOUT
+
+        datasource_postprocessing.DATABRICKS_CONNECT_SCRIPT_TIMEOUT = 1
+
+        try:
+            yield datasource_postprocessing.DATABRICKS_CONNECT_SCRIPT_TIMEOUT
+        finally:
+            datasource_postprocessing.DATABRICKS_CONNECT_SCRIPT_TIMEOUT = original_value
+
+    def test_databricks_connect_taking_too_long(
+        self, datasource_id, databricks_connect_in_path, short_script_timeout, jsons_for_extra_behavior
+    ):
+        """Prove that if databricks-connect takes longer than allowed to run, a ValueError will
+        be raised with an appropriate message.
+        """
+
+        # Respell the databricks-connect script to take longer than short_script_timeout seconds;
+        # expect that in a ValueError when calling postprocess_databricks.
+
+        script_path, _ = databricks_connect_in_path
+
+        # Respell the script to take longer than the new timeout, but to (try to) exit cleanly.
+        with script_path.open('w') as of:
+            of.write('#!/bin/sh\n')
+            of.write(f'sleep {short_script_timeout + 1}\n')
+            of.write('exit 0\n')
+
+        jsons_obj, specific_fields = jsons_for_extra_behavior
+        create_engine_kwargs = {'connect_args': jsons_obj.connect_args_dict}
+
+        with pytest.raises(ValueError, match='databricks-connect took longer than'):
+            datasource_postprocessing.postprocess_databricks(
+                datasource_id,
+                jsons_obj.dsn_dict,
+                create_engine_kwargs,
+            )
+
     def test_extra_behavior(
         self, datasource_id, databricks_connect_in_path, tmp_home, jsons_for_extra_behavior
     ):
@@ -673,13 +770,14 @@ def test_extra_behavior(
         # databricks_connect_in_path will create a different file.
         assert not dotconnect.exists()

-        # databricks_connect_in_path is the path where the fake script output was placed
-        assert databricks_connect_in_path.exists()
+        # databricks_connect_in_path's second member is where the fake script output was placed.
+        _, script_output = databricks_connect_in_path
+        assert script_output.exists()

        # Expect to find things in it. See ENG-5517.
        # We can only test that we ran this mock script and the known result
        # of our mock script. What the real one does ... ?
-        contents = databricks_connect_in_path.read().split()
+        contents = script_output.read().split('\n')
         assert len(contents) == 6
         assert contents[0] == 'y'
         assert contents[1] == f"https://{case_dict['hostname']}/"
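
The tests above lean on the databricks_connect_in_path fixture, whose body is only partially visible in this diff. Below is a rough, self-contained sketch of that kind of fixture, not the suite's exact code: the fixture name, file layout, and use of pytest's built-in tmp_path (rather than the suite's tmpdir) are assumptions. It shows the mechanics of planting a fake executable on $PATH and capturing what it is fed on stdin.

import os
import shutil
from pathlib import Path
from typing import Iterator, Tuple

import pytest


@pytest.fixture()
def fake_databricks_connect_in_path(tmp_path: Path) -> Iterator[Tuple[Path, Path]]:
    """Yield (script_path, script_output_path) for a fake 'databricks-connect' on $PATH."""
    bin_dir = tmp_path / 'bin'
    bin_dir.mkdir()
    script_output_path = tmp_path / 'databricks-connect-input.txt'

    # The fake script just records whatever it is fed on stdin, then exits cleanly.
    script_path = bin_dir / 'databricks-connect'
    script_path.write_text(f'#!/bin/sh\ncat > "{script_output_path}"\nexit 0\n')
    script_path.chmod(0o755)

    orig_path = os.environ['PATH']
    os.environ['PATH'] = f'{bin_dir}{os.pathsep}{orig_path}'
    try:
        # Sanity check: the postprocessor's which('databricks-connect') lookup will now succeed.
        assert shutil.which('databricks-connect')
        yield script_path, script_output_path
    finally:
        # Undo the $PATH change so other tests are unaffected.
        os.environ['PATH'] = orig_path

A test can then overwrite script_path to simulate failure modes (a nonzero exit, or sleeping past the timeout), which is exactly what the two new error-path tests do.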
