Skip to content

Commit 854a470

Browse files
Add a CredentialProvider reading from TaskContext in Spark Executors. (#232)
* Added a credential provider for accessing crdentials form the TaskContext on Spark Executors. * Added pylint directive to ignore broken import. * Adressed review comments. * Adressed review comments - added test making sure new provider does not break others if not available. * Adressed review comment. * Fixed linter complaints. * Addressed linter complaints. * Addressed linter complaints. * Addressed linter complaints. * Addressed linter complaints. * Added passing of ignoreTls flag to the executers. * Fixed typo. * nit * lint * lint * Renamed method. * set insecure classmethod -> staticmethod
1 parent 168609a commit 854a470

File tree

2 files changed

+73
-10
lines changed

2 files changed

+73
-10
lines changed

databricks_cli/configure/provider.py

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -185,19 +185,49 @@ def get_config(self):
185185

186186

187187
class DefaultConfigProvider(DatabricksConfigProvider):
188-
"""Prefers environment variables, and then the default profile."""
188+
"""Look for credentials in a chain of default locations."""
189189
def __init__(self):
190-
self.env_provider = EnvironmentVariableConfigProvider()
191-
self.default_profile_provider = ProfileConfigProvider()
190+
self._providers = (
191+
SparkTaskContextConfigProvider(),
192+
EnvironmentVariableConfigProvider(),
193+
ProfileConfigProvider()
194+
)
192195

193196
def get_config(self):
194-
env_config = self.env_provider.get_config()
195-
if env_config:
196-
return env_config
197+
for provider in self._providers:
198+
config = provider.get_config()
199+
if config is not None and config.is_valid:
200+
return config
201+
return None
202+
203+
204+
class SparkTaskContextConfigProvider(DatabricksConfigProvider):
205+
"""Loads credentials from Spark TaskContext if running in a Spark Executor."""
206+
207+
@staticmethod
208+
def _get_spark_task_context_or_none():
209+
try:
210+
from pyspark import TaskContext # pylint: disable=import-error
211+
return TaskContext.get()
212+
except ImportError:
213+
return None
197214

198-
profile_config = self.default_profile_provider.get_config()
199-
if profile_config:
200-
return profile_config
215+
@staticmethod
216+
def set_insecure(x):
217+
from pyspark import SparkContext # pylint: disable=import-error
218+
new_val = "True" if x else None
219+
SparkContext._active_spark_context.setLocalProperty("spark.databricks.ignoreTls", new_val)
220+
221+
def get_config(self):
222+
context = self._get_spark_task_context_or_none()
223+
if context is not None:
224+
host = context.getLocalProperty("spark.databricks.api.url")
225+
token = context.getLocalProperty("spark.databricks.token")
226+
insecure = context.getLocalProperty("spark.databricks.ignoreTls")
227+
config = DatabricksConfig.from_token(host=host, token=token, insecure=insecure)
228+
if config.is_valid:
229+
return config
230+
return None
201231

202232

203233
class EnvironmentVariableConfigProvider(DatabricksConfigProvider):

tests/configure/test_provider.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@
2828

2929
from databricks_cli.configure.provider import DatabricksConfig, DEFAULT_SECTION, \
3030
update_and_persist_config, get_config_for_profile, get_config, \
31-
set_config_provider, ProfileConfigProvider, _get_path, DatabricksConfigProvider
31+
set_config_provider, ProfileConfigProvider, _get_path, DatabricksConfigProvider,\
32+
SparkTaskContextConfigProvider
3233
from databricks_cli.utils import InvalidConfigurationError
3334

3435

@@ -133,6 +134,38 @@ def test_get_config_uses_default_profile():
133134
assert config.token == "hello"
134135

135136

137+
def test_get_config_uses_task_context_variable():
138+
class TaskContextMock(object):
139+
140+
def __init__(self):
141+
pass
142+
143+
def getLocalProperty(self, x): # NOQA
144+
if x == "spark.databricks.api.url":
145+
return "url"
146+
elif x == "spark.databricks.token":
147+
return "token"
148+
elif x == "spark.databricks.ignoreTls":
149+
return "True"
150+
else:
151+
raise Exception("should not get here.")
152+
153+
ctx_class = ("databricks_cli.configure.provider.SparkTaskContextConfigProvider."
154+
"_get_spark_task_context_or_none")
155+
with patch(ctx_class) as get_context_mock:
156+
get_context_mock.return_value = TaskContextMock()
157+
config = get_config()
158+
assert config.host == "url"
159+
assert config.token == "token"
160+
assert config.insecure == "True"
161+
assert config.username is None
162+
assert config.password is None
163+
164+
165+
def test_task_context_provider_does_not_break_stuff():
166+
assert SparkTaskContextConfigProvider().get_config() is None
167+
168+
136169
def test_get_config_uses_env_variable():
137170
with patch.dict('os.environ', {'DATABRICKS_HOST': TEST_HOST,
138171
'DATABRICKS_USERNAME': TEST_USER,

0 commit comments

Comments
 (0)